Skip to content

Commit 76e5c0d

Browse files
committed
refine
1 parent 0fb419e commit 76e5c0d

File tree

5 files changed

+111
-65
lines changed

5 files changed

+111
-65
lines changed

src/query/service/src/pipelines/processors/transforms/new_hash_join/hash_join_factory.rs

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@ use databend_common_expression::FunctionContext;
2626
use databend_common_expression::HashMethodKind;
2727
use databend_common_sql::plans::JoinType;
2828

29-
use crate::pipelines::processors::transforms::memory::outer_left_join::OuterLeftHashJoin;
30-
use crate::pipelines::processors::transforms::new_hash_join::common::CStyleCell;
31-
use crate::pipelines::processors::transforms::new_hash_join::grace::GraceHashJoinState;
29+
use super::common::CStyleCell;
30+
use super::grace::GraceHashJoinState;
31+
use super::memory::outer_left_join::OuterLeftHashJoin;
32+
use super::memory::NestedLoopJoin;
3233
use crate::pipelines::processors::transforms::BasicHashJoinState;
3334
use crate::pipelines::processors::transforms::GraceHashJoin;
3435
use crate::pipelines::processors::transforms::InnerHashJoin;
@@ -126,13 +127,29 @@ impl HashJoinFactory {
126127
}
127128

128129
match typ {
129-
JoinType::Inner => Ok(Box::new(InnerHashJoin::create(
130-
&self.ctx,
131-
self.function_ctx.clone(),
132-
self.hash_method.clone(),
133-
self.desc.clone(),
134-
self.create_basic_state(id)?,
135-
)?)),
130+
JoinType::Inner => {
131+
let state = self.create_basic_state(id)?;
132+
let nested_loop_desc = self
133+
.desc
134+
.create_nested_loop_desc(&settings, &self.function_ctx)?;
135+
136+
let inner = InnerHashJoin::create(
137+
&settings,
138+
self.function_ctx.clone(),
139+
self.hash_method.clone(),
140+
self.desc.clone(),
141+
state.clone(),
142+
nested_loop_desc
143+
.as_ref()
144+
.map(|desc| desc.nested_loop_join_threshold)
145+
.unwrap_or_default(),
146+
)?;
147+
148+
match nested_loop_desc {
149+
Some(desc) => Ok(Box::new(NestedLoopJoin::create(inner, state, desc))),
150+
None => Ok(Box::new(inner)),
151+
}
152+
}
136153
JoinType::Left => Ok(Box::new(OuterLeftHashJoin::create(
137154
&self.ctx,
138155
self.function_ctx.clone(),
@@ -148,11 +165,12 @@ impl HashJoinFactory {
148165
match typ {
149166
JoinType::Inner => {
150167
let inner_hash_join = InnerHashJoin::create(
151-
&self.ctx,
168+
&self.ctx.get_settings(),
152169
self.function_ctx.clone(),
153170
self.hash_method.clone(),
154171
self.desc.clone(),
155172
self.create_basic_state(id)?,
173+
0,
156174
)?;
157175

158176
Ok(Box::new(GraceHashJoin::create(

src/query/service/src/pipelines/processors/transforms/new_hash_join/memory/basic.rs

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -354,17 +354,10 @@ impl BasicHashJoin {
354354
let mut progress = ProgressValues::default();
355355
let mut plain = vec![];
356356
while let Some(chunk_index) = self.state.steal_chunk_index() {
357-
let chunk_mut = &mut self.state.chunks.as_mut()[chunk_index];
358-
359-
let mut chunk_block = DataBlock::empty();
360-
std::mem::swap(chunk_mut, &mut chunk_block);
361-
357+
let chunk_block = &self.state.chunks[chunk_index];
362358
progress.rows += chunk_block.num_rows();
363359
progress.bytes += chunk_block.memory_size();
364-
365-
*chunk_mut = chunk_block.clone();
366-
367-
plain.push(chunk_block);
360+
plain.push(chunk_block.clone());
368361
}
369362
debug_assert!(matches!(
370363
*self.state.hash_table,

src/query/service/src/pipelines/processors/transforms/new_hash_join/memory/inner_join.rs

Lines changed: 11 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ use std::ops::Deref;
1616
use std::sync::Arc;
1717

1818
use databend_common_base::base::ProgressValues;
19-
use databend_common_catalog::table_context::TableContext;
2019
use databend_common_column::bitmap::Bitmap;
2120
use databend_common_exception::ErrorCode;
2221
use databend_common_exception::Result;
@@ -27,10 +26,10 @@ use databend_common_expression::DataBlock;
2726
use databend_common_expression::FilterExecutor;
2827
use databend_common_expression::FunctionContext;
2928
use databend_common_expression::HashMethodKind;
29+
use databend_common_settings::Settings;
3030

3131
use super::basic::BasicHashJoin;
3232
use super::basic_state::BasicHashJoinState;
33-
use super::LoopJoinStream;
3433
use crate::pipelines::processors::transforms::build_runtime_filter_packet;
3534
use crate::pipelines::processors::transforms::new_hash_join::hashtable::basic::ProbeStream;
3635
use crate::pipelines::processors::transforms::new_hash_join::hashtable::basic::ProbedRows;
@@ -41,10 +40,8 @@ use crate::pipelines::processors::transforms::new_hash_join::join::JoinStream;
4140
use crate::pipelines::processors::transforms::new_hash_join::performance::PerformanceContext;
4241
use crate::pipelines::processors::transforms::HashJoinHashTable;
4342
use crate::pipelines::processors::transforms::JoinRuntimeFilterPacket;
44-
use crate::pipelines::processors::transforms::NestedLoopDesc;
4543
use crate::pipelines::processors::transforms::RuntimeFiltersDesc;
4644
use crate::pipelines::processors::HashJoinDesc;
47-
use crate::sessions::QueryContext;
4845

4946
pub struct InnerHashJoin {
5047
pub(crate) basic_hash_join: BasicHashJoin,
@@ -53,35 +50,23 @@ pub struct InnerHashJoin {
5350
pub(crate) function_ctx: FunctionContext,
5451
pub(crate) basic_state: Arc<BasicHashJoinState>,
5552
pub(crate) performance_context: PerformanceContext,
56-
nested_loop_filter: Option<FilterExecutor>,
57-
nested_loop_field_reorder: Option<Vec<usize>>,
5853
}
5954

6055
impl InnerHashJoin {
6156
pub fn create(
62-
ctx: &QueryContext,
57+
settings: &Settings,
6358
function_ctx: FunctionContext,
6459
method: HashMethodKind,
6560
desc: Arc<HashJoinDesc>,
6661
state: Arc<BasicHashJoinState>,
62+
nested_loop_join_threshold: usize,
6763
) -> Result<Self> {
68-
let settings = ctx.get_settings();
6964
let block_size = settings.get_max_block_size()? as usize;
7065

7166
let context = PerformanceContext::create(block_size, desc.clone(), function_ctx.clone());
7267

73-
let (nested_loop_filter, nested_loop_field_reorder, nested_loop_join_threshold) =
74-
match desc.create_nested_loop_desc(&settings, &function_ctx)? {
75-
Some(NestedLoopDesc {
76-
filter,
77-
field_reorder,
78-
nested_loop_join_threshold,
79-
}) => (Some(filter), field_reorder, nested_loop_join_threshold),
80-
None => (None, None, 0),
81-
};
82-
8368
let basic_hash_join = BasicHashJoin::create(
84-
&settings,
69+
settings,
8570
function_ctx.clone(),
8671
method,
8772
desc.clone(),
@@ -95,8 +80,6 @@ impl InnerHashJoin {
9580
function_ctx,
9681
basic_state: state,
9782
performance_context: context,
98-
nested_loop_filter,
99-
nested_loop_field_reorder,
10083
})
10184
}
10285
}
@@ -131,23 +114,6 @@ impl Join for InnerHashJoin {
131114

132115
self.basic_hash_join.finalize_chunks();
133116

134-
match &*self.basic_state.hash_table {
135-
HashJoinHashTable::Null => {
136-
return Err(ErrorCode::AbortedQuery(
137-
"Aborted query, because the hash table is uninitialized.",
138-
))
139-
}
140-
HashJoinHashTable::NestedLoop(build_blocks) => {
141-
let nested = Box::new(LoopJoinStream::new(data, build_blocks));
142-
return Ok(InnerHashJoinFilterStream::create(
143-
nested,
144-
self.nested_loop_filter.as_mut().unwrap(),
145-
self.nested_loop_field_reorder.as_deref(),
146-
));
147-
}
148-
_ => (),
149-
}
150-
151117
let probe_keys = self.desc.probe_key(&data, &self.function_ctx)?;
152118

153119
let mut keys = DataBlock::new(probe_keys, data.num_rows());
@@ -175,7 +141,12 @@ impl Join for InnerHashJoin {
175141
&mut self.performance_context.probe_result,
176142
)
177143
}
178-
HashJoinHashTable::Null | HashJoinHashTable::NestedLoop(_) => unreachable!(),
144+
HashJoinHashTable::Null => {
145+
return Err(ErrorCode::AbortedQuery(
146+
"Aborted query, because the hash table is uninitialized.",
147+
));
148+
}
149+
HashJoinHashTable::NestedLoop(_) => unreachable!(),
179150
});
180151

181152
match &mut self.performance_context.filter_executor {
@@ -292,7 +263,7 @@ impl<'a> JoinStream for InnerHashJoinStream<'a> {
292263
}
293264
}
294265

295-
struct InnerHashJoinFilterStream<'a> {
266+
pub(super) struct InnerHashJoinFilterStream<'a> {
296267
inner: Box<dyn JoinStream + 'a>,
297268
filter_executor: &'a mut FilterExecutor,
298269
field_reorder: Option<&'a [usize]>,

src/query/service/src/pipelines/processors/transforms/new_hash_join/memory/nested_loop.rs

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,72 @@
1313
// limitations under the License.
1414

1515
use std::collections::VecDeque;
16+
use std::sync::Arc;
1617

18+
use databend_common_base::base::ProgressValues;
1719
use databend_common_exception::Result;
1820
use databend_common_expression::types::DataType;
1921
use databend_common_expression::BlockEntry;
2022
use databend_common_expression::DataBlock;
2123
use databend_common_expression::Scalar;
2224

25+
use super::inner_join::InnerHashJoinFilterStream;
26+
use crate::pipelines::processors::transforms::new_hash_join::join::EmptyJoinStream;
27+
use crate::pipelines::processors::transforms::BasicHashJoinState;
28+
use crate::pipelines::processors::transforms::HashJoinHashTable;
29+
use crate::pipelines::processors::transforms::Join;
30+
use crate::pipelines::processors::transforms::JoinRuntimeFilterPacket;
2331
use crate::pipelines::processors::transforms::JoinStream;
32+
use crate::pipelines::processors::transforms::NestedLoopDesc;
33+
use crate::pipelines::processors::transforms::RuntimeFiltersDesc;
2434

25-
pub struct LoopJoinStream<'a> {
35+
pub struct NestedLoopJoin<T> {
36+
inner: T,
37+
basic_state: Arc<BasicHashJoinState>,
38+
desc: NestedLoopDesc,
39+
}
40+
41+
impl<T> NestedLoopJoin<T> {
42+
pub fn create(inner: T, basic_state: Arc<BasicHashJoinState>, desc: NestedLoopDesc) -> Self {
43+
Self {
44+
inner,
45+
basic_state,
46+
desc,
47+
}
48+
}
49+
}
50+
51+
impl<T: Join> Join for NestedLoopJoin<T> {
52+
fn add_block(&mut self, data: Option<DataBlock>) -> Result<()> {
53+
self.inner.add_block(data)
54+
}
55+
56+
fn final_build(&mut self) -> Result<Option<ProgressValues>> {
57+
self.inner.final_build()
58+
}
59+
60+
fn build_runtime_filter(&self, desc: &RuntimeFiltersDesc) -> Result<JoinRuntimeFilterPacket> {
61+
self.inner.build_runtime_filter(desc)
62+
}
63+
64+
fn probe_block(&mut self, data: DataBlock) -> Result<Box<dyn JoinStream + '_>> {
65+
if data.is_empty() || *self.basic_state.build_rows == 0 {
66+
return Ok(Box::new(EmptyJoinStream));
67+
}
68+
let HashJoinHashTable::NestedLoop(build_blocks) = &*self.basic_state.hash_table else {
69+
return self.inner.probe_block(data);
70+
};
71+
72+
let nested = Box::new(LoopJoinStream::new(data, build_blocks));
73+
Ok(InnerHashJoinFilterStream::create(
74+
nested,
75+
&mut self.desc.filter,
76+
self.desc.field_reorder.as_deref(),
77+
))
78+
}
79+
}
80+
81+
struct LoopJoinStream<'a> {
2682
probe_rows: VecDeque<Vec<Scalar>>,
2783
probe_types: Vec<DataType>,
2884
build_blocks: &'a [DataBlock],

src/query/service/src/pipelines/processors/transforms/new_hash_join/transform_hash_join.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
use std::any::Any;
1616
use std::fmt::Debug;
1717
use std::fmt::Formatter;
18+
use std::marker::PhantomPinned;
1819
use std::sync::Arc;
1920

2021
use databend_common_exception::Result;
@@ -42,6 +43,7 @@ pub struct TransformHashJoin {
4243
stage_sync_barrier: Arc<Barrier>,
4344
projection: ColumnSet,
4445
rf_desc: Arc<RuntimeFiltersDesc>,
46+
_p: PhantomPinned,
4547
}
4648

4749
impl TransformHashJoin {
@@ -67,6 +69,7 @@ impl TransformHashJoin {
6769
finished: false,
6870
build_data: None,
6971
}),
72+
_p: PhantomPinned,
7073
}))
7174
}
7275
}
@@ -117,8 +120,7 @@ impl Processor for TransformHashJoin {
117120
}
118121
}
119122

120-
#[allow(clippy::missing_transmute_annotations)]
121-
fn process(&mut self) -> Result<()> {
123+
fn process<'a>(&'a mut self) -> Result<()> {
122124
match &mut self.stage {
123125
Stage::Finished => Ok(()),
124126
Stage::Build(state) => {
@@ -144,7 +146,9 @@ impl Processor for TransformHashJoin {
144146
if let Some(probe_data) = state.input_data.take() {
145147
let stream = self.join.probe_block(probe_data)?;
146148
// This is safe because both join and stream are properties of the struct.
147-
state.stream = Some(unsafe { std::mem::transmute(stream) });
149+
state.stream = Some(unsafe {
150+
std::mem::transmute::<Box<dyn JoinStream + 'a>, Box<dyn JoinStream>>(stream)
151+
});
148152
}
149153

150154
if let Some(mut stream) = state.stream.take() {
@@ -161,7 +165,11 @@ impl Processor for TransformHashJoin {
161165
if let Some(final_stream) = self.join.final_probe()? {
162166
state.initialize = true;
163167
// This is safe because both join and stream are properties of the struct.
164-
state.stream = Some(unsafe { std::mem::transmute(final_stream) });
168+
state.stream = Some(unsafe {
169+
std::mem::transmute::<Box<dyn JoinStream + 'a>, Box<dyn JoinStream>>(
170+
final_stream,
171+
)
172+
});
165173
} else {
166174
state.finished = true;
167175
}

0 commit comments

Comments
 (0)