@@ -26,6 +26,7 @@
import org.apache.paimon.disk.IOManager;
import org.apache.paimon.io.DataFileMeta;
import org.apache.paimon.manifest.PartitionEntry;
import org.apache.paimon.partition.PartitionPredicate;
import org.apache.paimon.predicate.Predicate;
import org.apache.paimon.predicate.PredicateBuilder;
import org.apache.paimon.reader.RecordReader;
@@ -39,7 +40,9 @@
import org.apache.paimon.table.source.TableRead;
import org.apache.paimon.types.RowType;
import org.apache.paimon.utils.ChainTableUtils;
import org.apache.paimon.utils.Filter;
import org.apache.paimon.utils.InternalRowPartitionComputer;
import org.apache.paimon.utils.Pair;
import org.apache.paimon.utils.RowDataToObjectArrayConverter;

import java.io.IOException;
@@ -50,6 +53,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

@@ -130,6 +134,9 @@ public static class ChainTableBatchScan extends FallbackReadScan {
private final CoreOptions options;
private final RecordComparator partitionComparator;
private final ChainGroupReadTable chainGroupReadTable;
private PartitionPredicate partitionPredicate;
private Predicate dataPredicate;
private Filter<Integer> bucketFilter;

public ChainTableBatchScan(
DataTableScan mainScan,
@@ -153,10 +160,98 @@ public ChainTableBatchScan(
tableSchema.logicalPartitionType().getFieldTypes());
}

@Override
public ChainTableBatchScan withFilter(Predicate predicate) {
super.withFilter(predicate);
if (predicate != null) {
Pair<Optional<PartitionPredicate>, List<Predicate>> pair =
PartitionPredicate.splitPartitionPredicatesAndDataPredicates(
predicate,
tableSchema.logicalRowType(),
tableSchema.partitionKeys());
setPartitionPredicate(pair.getLeft().orElse(null));
dataPredicate =
pair.getRight().isEmpty() ? null : PredicateBuilder.and(pair.getRight());
}
return this;
}

@Override
public ChainTableBatchScan withPartitionFilter(Map<String, String> partitionSpec) {
super.withPartitionFilter(partitionSpec);
if (partitionSpec != null) {
setPartitionPredicate(
PartitionPredicate.fromMap(
tableSchema.logicalPartitionType(),
partitionSpec,
options.partitionDefaultName()));
}
return this;
}

@Override
public ChainTableBatchScan withPartitionFilter(List<BinaryRow> partitions) {
super.withPartitionFilter(partitions);
if (partitions != null) {
setPartitionPredicate(
PartitionPredicate.fromMultiple(
tableSchema.logicalPartitionType(), partitions));
}
return this;
}

@Override
public ChainTableBatchScan withPartitionsFilter(List<Map<String, String>> partitions) {
super.withPartitionsFilter(partitions);
if (partitions != null) {
setPartitionPredicate(
PartitionPredicate.fromMaps(
tableSchema.logicalPartitionType(),
partitions,
options.partitionDefaultName()));
}
return this;
}

@Override
public ChainTableBatchScan withPartitionFilter(PartitionPredicate partitionPredicate) {
super.withPartitionFilter(partitionPredicate);
if (partitionPredicate != null) {
setPartitionPredicate(partitionPredicate);
}
return this;
}

@Override
public ChainTableBatchScan withPartitionFilter(Predicate partitionPredicate) {
super.withPartitionFilter(partitionPredicate);
if (partitionPredicate != null) {
setPartitionPredicate(
PartitionPredicate.fromPredicate(
tableSchema.logicalPartitionType(), partitionPredicate));
}
return this;
}

@Override
public ChainTableBatchScan withBucketFilter(Filter<Integer> bucketFilter) {
this.bucketFilter = bucketFilter;
super.withBucketFilter(bucketFilter);
return this;
}

/**
* Builds a plan for chain tables.
*
* <p>Partitions that exist in the snapshot branch (based on partition predicates only) are
* treated as complete and are read directly from snapshot, subject to row-level predicates.
* Partitions that exist only in the delta branch are planned as chain splits by pairing
* each delta partition with the latest snapshot partition at or before it (if any), so the
* reader sees a full partition view.
*/
@Override
public Plan plan() {
List<Split> splits = new ArrayList<>();
Set<BinaryRow> completePartitions = new HashSet<>();
PredicateBuilder builder = new PredicateBuilder(tableSchema.logicalPartitionType());
for (Split split : mainScan.plan().splits()) {
DataSplit dataSplit = (DataSplit) split;
@@ -170,30 +265,36 @@ public Plan plan() {
new ChainSplit(
dataSplit.partition(),
dataSplit.dataFiles(),
fileBucketPathMapping,
fileBranchMapping));
completePartitions.add(dataSplit.partition());
fileBranchMapping,
fileBucketPathMapping));
}
List<BinaryRow> remainingPartitions =
fallbackScan.listPartitions().stream()
.filter(p -> !completePartitions.contains(p))

Set<BinaryRow> snapshotPartitions =
new HashSet<>(
newPartitionListingScan(true, partitionPredicate).listPartitions());

DataTableScan deltaPartitionScan = newPartitionListingScan(false, partitionPredicate);
List<BinaryRow> deltaPartitions =
deltaPartitionScan.listPartitions().stream()
.filter(p -> !snapshotPartitions.contains(p))
.sorted(partitionComparator)
.collect(Collectors.toList());
if (!remainingPartitions.isEmpty()) {
fallbackScan.withPartitionFilter(remainingPartitions);
List<BinaryRow> deltaPartitions = fallbackScan.listPartitions();
deltaPartitions =
deltaPartitions.stream()
.sorted(partitionComparator)
.collect(Collectors.toList());

if (!deltaPartitions.isEmpty()) {
BinaryRow maxPartition = deltaPartitions.get(deltaPartitions.size() - 1);
Predicate snapshotPredicate =
ChainTableUtils.createTriangularPredicate(
maxPartition,
partitionConverter,
builder::equal,
builder::lessThan);
mainScan.withPartitionFilter(snapshotPredicate);
List<BinaryRow> candidateSnapshotPartitions = mainScan.listPartitions();
PartitionPredicate snapshotPartitionPredicate =
PartitionPredicate.fromPredicate(
tableSchema.logicalPartitionType(), snapshotPredicate);
DataTableScan snapshotPartitionsScan =
newPartitionListingScan(true, snapshotPartitionPredicate);
List<BinaryRow> candidateSnapshotPartitions =
snapshotPartitionsScan.listPartitions();
candidateSnapshotPartitions =
candidateSnapshotPartitions.stream()
.sorted(partitionComparator)
@@ -202,8 +303,8 @@ public Plan plan() {
ChainTableUtils.findFirstLatestPartitions(
deltaPartitions, candidateSnapshotPartitions, partitionComparator);
for (Map.Entry<BinaryRow, BinaryRow> partitionParis : partitionMapping.entrySet()) {
DataTableScan snapshotScan = chainGroupReadTable.newSnapshotScan();
DataTableScan deltaScan = chainGroupReadTable.newDeltaScan();
DataTableScan snapshotScan = newFilteredScan(true);
DataTableScan deltaScan = newFilteredScan(false);
if (partitionParis.getValue() == null) {
List<Predicate> predicates = new ArrayList<>();
predicates.add(
Expand Down Expand Up @@ -281,8 +382,8 @@ public Plan plan() {
.flatMap(
datsSplit -> datsSplit.dataFiles().stream())
.collect(Collectors.toList()),
fileBucketPathMapping,
fileBranchMapping);
fileBranchMapping,
fileBucketPathMapping);
splits.add(split);
}
}
@@ -292,7 +393,49 @@ public Plan plan() {

@Override
public List<PartitionEntry> listPartitionEntries() {
return super.listPartitionEntries();
DataTableScan snapshotScan = newPartitionListingScan(true, partitionPredicate);
DataTableScan deltaScan = newPartitionListingScan(false, partitionPredicate);
List<PartitionEntry> partitionEntries =
new ArrayList<>(snapshotScan.listPartitionEntries());
Set<BinaryRow> partitions =
partitionEntries.stream()
.map(PartitionEntry::partition)
.collect(Collectors.toSet());
List<PartitionEntry> fallBackPartitionEntries = deltaScan.listPartitionEntries();
fallBackPartitionEntries.stream()
.filter(e -> !partitions.contains(e.partition()))
.forEach(partitionEntries::add);
return partitionEntries;
}

private void setPartitionPredicate(PartitionPredicate predicate) {
this.partitionPredicate = predicate;
}

private DataTableScan newPartitionListingScan(
boolean snapshot, PartitionPredicate scanPartitionPredicate) {
DataTableScan scan =
snapshot
? chainGroupReadTable.newSnapshotScan()
: chainGroupReadTable.newDeltaScan();
if (scanPartitionPredicate != null) {
scan.withPartitionFilter(scanPartitionPredicate);
}
return scan;
}

private DataTableScan newFilteredScan(boolean snapshot) {
DataTableScan scan =
snapshot
? chainGroupReadTable.newSnapshotScan()
: chainGroupReadTable.newDeltaScan();
if (dataPredicate != null) {
scan.withFilter(dataPredicate);
}
if (bucketFilter != null) {
scan.withBucketFilter(bucketFilter);
}
return scan;
}
}

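A rough, self-contained sketch of the pairing rule described in the plan() javadoc above. It is not part of this change: the actual implementation pairs BinaryRow partitions via ChainTableUtils.findFirstLatestPartitions and partitionComparator, while the class and method names below are made up for illustration and use plain strings as partitions.

import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class ChainPairingSketch {

    // For every delta-only partition, pick the latest snapshot partition that is
    // less than or equal to it, or null when none exists. Both lists are assumed
    // to be sorted ascending, mirroring the partitionComparator ordering.
    static Map<String, String> pairWithLatestSnapshot(
            List<String> deltaPartitions,
            List<String> snapshotPartitions,
            Comparator<String> comparator) {
        Map<String, String> mapping = new LinkedHashMap<>();
        int i = 0;
        String latest = null;
        for (String delta : deltaPartitions) {
            while (i < snapshotPartitions.size()
                    && comparator.compare(snapshotPartitions.get(i), delta) <= 0) {
                latest = snapshotPartitions.get(i++);
            }
            mapping.put(delta, latest);
        }
        return mapping;
    }

    public static void main(String[] args) {
        // Delta partitions 20250812 and 20250813 both chain on top of the newest
        // snapshot partition at or before them (20250811); a delta partition older
        // than every snapshot partition would map to null.
        System.out.println(
                pairWithLatestSnapshot(
                        List.of("20250812", "20250813"),
                        List.of("20250810", "20250811"),
                        Comparator.naturalOrder()));
    }
}

With such a mapping, plan() builds chain splits from a delta scan over each delta partition plus, when the mapped snapshot partition is non-null, a snapshot scan over that partition; both scans go through newFilteredScan, so the row-level dataPredicate and bucketFilter still apply.
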
@@ -183,6 +183,6 @@ public static ChainSplit deserialize(DataInputView in) throws IOException {
}

return new ChainSplit(
logicalPartition, dataFiles, fileBucketPathMapping, fileBranchMapping);
logicalPartition, dataFiles, fileBranchMapping, fileBucketPathMapping);
}
}
@@ -56,7 +56,7 @@ public void testChainSplitSerde() throws IOException, ClassNotFoundException {
}
ChainSplit split =
new ChainSplit(
logicalPartition, dataFiles, fileBucketPathMapping, fileBranchMapping);
logicalPartition, dataFiles, fileBranchMapping, fileBucketPathMapping);
byte[] bytes = InstantiationUtil.serializeObject(split);
ChainSplit newSplit =
InstantiationUtil.deserializeObject(bytes, ChainSplit.class.getClassLoader());
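
The Spark tests that follow exercise filters such as "dt = '20250811' and t1 = 1". Conceptually, ChainTableBatchScan.withFilter above separates conjuncts on partition columns from conjuncts on data columns: the partition part prunes partitions during planning, while the data part is pushed into the per-partition snapshot and delta scans. A minimal standalone illustration of that split, deliberately independent of Paimon's PartitionPredicate API (Leaf, split and PredicateSplitSketch are made-up names):

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class PredicateSplitSketch {

    // A made-up equality leaf: field = literal.
    record Leaf(String field, Object literal) {}

    // Split a conjunction of leaves into a partition part and a data part.
    static Map<String, List<Leaf>> split(List<Leaf> conjuncts, Set<String> partitionKeys) {
        List<Leaf> partitionPart = new ArrayList<>();
        List<Leaf> dataPart = new ArrayList<>();
        for (Leaf leaf : conjuncts) {
            (partitionKeys.contains(leaf.field()) ? partitionPart : dataPart).add(leaf);
        }
        return Map.of("partition", partitionPart, "data", dataPart);
    }

    public static void main(String[] args) {
        // Mirrors "where dt = '20250811' and t1 = 1" on a table partitioned by dt:
        // the dt conjunct prunes partitions up front, the t1 conjunct filters rows.
        System.out.println(
                split(List.of(new Leaf("dt", "20250811"), new Leaf("t1", 1)), Set.of("dt")));
    }
}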
@@ -143,7 +143,7 @@ public void testChainTable(@TempDir java.nio.file.Path tempDir) throws IOExcepti
spark.sql(
"insert overwrite table `my_db1`.`chain_test` partition (dt = '20250811') values (2, 2, '1-1' ),(4, 1, '1' );");
spark.sql(
"insert overwrite table `my_db1`.`chain_test` partition (dt = '20250812') values (3, 2, '1-1' ),(4, 2, '1-1' );");
"insert overwrite table `my_db1`.`chain_test` partition (dt = '20250812') values (3, 2, '1-1' ),(4, 2, '1-1' ),(7, 1, 'd7' );");
spark.sql(
"insert overwrite table `my_db1`.`chain_test` partition (dt = '20250813') values (5, 1, '1' ),(6, 1, '1' );");
spark.sql(
@@ -202,6 +202,48 @@ public void testChainTable(@TempDir java.nio.file.Path tempDir) throws IOExcepti
"[3,1,1,20250811]",
"[4,1,1,20250811]");

/** Chain read with filter */
assertThat(
spark
.sql(
"SELECT * FROM `my_db1`.`chain_test` where dt = '20250811' and t1 = 1")
.collectAsList().stream()
.map(Row::toString)
.collect(Collectors.toList()))
.containsExactlyInAnyOrder("[1,2,1-1,20250811]");
assertThat(
spark
.sql(
"SELECT * FROM `my_db1`.`chain_test` where dt = '20250811' and t1 = 4")
.collectAsList().stream()
.map(Row::toString)
.collect(Collectors.toList()))
.containsExactlyInAnyOrder("[4,1,1,20250811]");
assertThat(
spark
.sql(
"SELECT * FROM `my_db1`.`chain_test` where dt = '20250811' and t1 = 7")
.collectAsList().stream()
.map(Row::toString)
.collect(Collectors.toList()))
.isEmpty();

assertThat(
spark
.sql(
"SELECT * FROM `my_db1`.`chain_test` where dt in ('20250811', '20250812') and t1 = 1")
.collectAsList().stream()
.map(Row::toString)
.collect(Collectors.toList()))
.containsExactlyInAnyOrder("[1,2,1-1,20250811]", "[1,2,1-1,20250812]");

/** Snapshot read with filter */
assertThat(
spark.sql(
"SELECT * FROM `my_db1`.`chain_test` where dt = '20250812' and t1 = 7")
.collectAsList())
.isEmpty();

/** Multi partition Read */
assertThat(
spark
@@ -246,7 +288,8 @@ public void testChainTable(@TempDir java.nio.file.Path tempDir) throws IOExcepti
"[2,2,1-1,20250811]",
"[4,1,1,20250811]",
"[3,2,1-1,20250812]",
"[4,2,1-1,20250812]");
"[4,2,1-1,20250812]",
"[7,1,d7,20250812]");

/** Hybrid read */
assertThat(
@@ -402,6 +445,16 @@ public void testHourlyChainTable(@TempDir java.nio.file.Path tempDir) throws IOE
"[3,1,1,20250810,23]",
"[4,1,1,20250810,23]");

/** Chain read with non-partition filter */
assertThat(
spark
.sql(
"SELECT * FROM `my_db1`.`chain_test` where dt = '20250810' and hour = '23' and t1 = 1")
.collectAsList().stream()
.map(Row::toString)
.collect(Collectors.toList()))
.containsExactlyInAnyOrder("[1,2,1-1,20250810,23]");

/** Multi partition Read */
assertThat(
spark