79 changes: 56 additions & 23 deletions datafusion/proto-common/src/from_proto/mod.rs
@@ -28,7 +28,12 @@ use arrow::datatypes::{
DataType, Field, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema,
TimeUnit, UnionFields, UnionMode, i256,
};
use arrow::ipc::{reader::read_record_batch, root_as_message};
use arrow::ipc::{
convert::fb_to_schema,
reader::{read_dictionary, read_record_batch},
root_as_message,
writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions},
};

use datafusion_common::{
Column, ColumnStatistics, Constraint, Constraints, DFSchema, DFSchemaRef,
@@ -384,7 +389,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
Value::Float32Value(v) => Self::Float32(Some(*v)),
Value::Float64Value(v) => Self::Float64(Some(*v)),
Value::Date32Value(v) => Self::Date32(Some(*v)),
// ScalarValue::List is serialized using arrow IPC format
// Nested ScalarValue types are serialized using arrow IPC format
Value::ListValue(v)
| Value::FixedSizeListValue(v)
| Value::LargeListValue(v)
@@ -401,55 +406,83 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
schema_ref.try_into()?
} else {
return Err(Error::General(
"Invalid schema while deserializing ScalarValue::List"
"Invalid schema while deserializing nested ScalarValue"
.to_string(),
));
};

// IPC dictionary batch IDs are assigned when encoding the schema, but our protobuf
// `Schema` doesn't preserve those IDs. Reconstruct them deterministically by
// round-tripping the schema through IPC.
let schema: Schema = {
let ipc_gen = IpcDataGenerator {};
let write_options = IpcWriteOptions::default();
let mut dict_tracker = DictionaryTracker::new(false);
let encoded_schema = ipc_gen.schema_to_bytes_with_dictionary_tracker(
&schema,
&mut dict_tracker,
&write_options,
);
let message =
root_as_message(encoded_schema.ipc_message.as_slice()).map_err(
|e| {
Error::General(format!(
"Error IPC schema message while deserializing nested ScalarValue: {e}"
))
},
)?;
let ipc_schema = message.header_as_schema().ok_or_else(|| {
Error::General(
"Unexpected message type deserializing nested ScalarValue schema"
.to_string(),
)
})?;
fb_to_schema(ipc_schema)
};

let message = root_as_message(ipc_message.as_slice()).map_err(|e| {
Error::General(format!(
"Error IPC message while deserializing ScalarValue::List: {e}"
"Error IPC message while deserializing nested ScalarValue: {e}"
))
})?;
let buffer = Buffer::from(arrow_data.as_slice());

let ipc_batch = message.header_as_record_batch().ok_or_else(|| {
Error::General(
"Unexpected message type deserializing ScalarValue::List"
"Unexpected message type deserializing nested ScalarValue"
.to_string(),
)
})?;

let dict_by_id: HashMap<i64,ArrayRef> = dictionaries.iter().map(|protobuf::scalar_nested_value::Dictionary { ipc_message, arrow_data }| {
let mut dict_by_id: HashMap<i64, ArrayRef> = HashMap::new();
for protobuf::scalar_nested_value::Dictionary {
ipc_message,
arrow_data,
} in dictionaries
{
let message = root_as_message(ipc_message.as_slice()).map_err(|e| {
Error::General(format!(
"Error IPC message while deserializing ScalarValue::List dictionary message: {e}"
"Error IPC message while deserializing nested ScalarValue dictionary message: {e}"
))
})?;
let buffer = Buffer::from(arrow_data.as_slice());

let dict_batch = message.header_as_dictionary_batch().ok_or_else(|| {
Error::General(
"Unexpected message type deserializing ScalarValue::List dictionary message"
"Unexpected message type deserializing nested ScalarValue dictionary message"
.to_string(),
)
})?;

let id = dict_batch.id();

let record_batch = read_record_batch(
read_dictionary(
&buffer,
dict_batch.data().unwrap(),
Arc::new(schema.clone()),
&Default::default(),
None,
dict_batch,
&schema,
&mut dict_by_id,
&message.version(),
)?;

let values: ArrayRef = Arc::clone(record_batch.column(0));

Ok((id, values))
}).collect::<datafusion_common::Result<HashMap<_, _>>>()?;
)
.map_err(|e| arrow_datafusion_err!(e))
.map_err(|e| e.context("Decoding nested ScalarValue dictionary"))?;
}

let record_batch = read_record_batch(
&buffer,
@@ -460,7 +493,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
&message.version(),
)
.map_err(|e| arrow_datafusion_err!(e))
.map_err(|e| e.context("Decoding ScalarValue::List Value"))?;
.map_err(|e| e.context("Decoding nested ScalarValue value"))?;
let arr = record_batch.column(0);
match value {
Value::ListValue(_) => {
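
Note: the key idea in the from_proto change above is that the protobuf `Schema` does not carry IPC dictionary batch IDs, so the schema is round-tripped through the IPC flatbuffer representation to re-derive those IDs deterministically before any dictionary batches are read. Below is a minimal, self-contained sketch of that round-trip, not code from the PR; it assumes a recent arrow-rs that provides `schema_to_bytes_with_dictionary_tracker`, and the helper name `schema_with_ipc_dictionary_ids` is hypothetical.

use arrow::datatypes::{DataType, Field, Schema};
use arrow::ipc::{
    convert::fb_to_schema,
    root_as_message,
    writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions},
};

// Round-trip a Schema through the IPC schema message so that dictionary-encoded
// fields pick up the same IDs the IPC writer would assign on the encoding side.
fn schema_with_ipc_dictionary_ids(schema: &Schema) -> Schema {
    let ipc_gen = IpcDataGenerator {};
    let mut dict_tracker = DictionaryTracker::new(false);
    // Encoding the schema is what assigns dictionary IDs inside `dict_tracker`.
    let encoded = ipc_gen.schema_to_bytes_with_dictionary_tracker(
        schema,
        &mut dict_tracker,
        &IpcWriteOptions::default(),
    );
    // Parse the flatbuffer back into an arrow Schema that carries those IDs.
    let message = root_as_message(encoded.ipc_message.as_slice())
        .expect("valid IPC schema message");
    fb_to_schema(message.header_as_schema().expect("schema header"))
}

fn main() {
    let schema = Schema::new(vec![Field::new(
        "item",
        DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
        true,
    )]);
    println!("{:#?}", schema_with_ipc_dictionary_ids(&schema));
}
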
13 changes: 10 additions & 3 deletions datafusion/proto-common/src/to_proto/mod.rs
@@ -1010,21 +1010,28 @@ fn create_proto_scalar<I, T: FnOnce(&I) -> protobuf::scalar_value::Value>(
Ok(protobuf::ScalarValue { value: Some(value) })
}

// ScalarValue::List / FixedSizeList / LargeList / Struct / Map are serialized using
// Nested ScalarValue types (List / FixedSizeList / LargeList / Struct / Map) are serialized using
// Arrow IPC messages as a single column RecordBatch
fn encode_scalar_nested_value(
arr: ArrayRef,
val: &ScalarValue,
) -> Result<protobuf::ScalarValue, Error> {
let batch = RecordBatch::try_from_iter(vec![("field_name", arr)]).map_err(|e| {
Error::General(format!(
"Error creating temporary batch while encoding ScalarValue::List: {e}"
"Error creating temporary batch while encoding nested ScalarValue: {e}"
))
})?;

let ipc_gen = IpcDataGenerator {};
let mut dict_tracker = DictionaryTracker::new(false);
let write_options = IpcWriteOptions::default();
// The IPC writer requires pre-allocated dictionary IDs (normally assigned when
// serializing the schema). Populate `dict_tracker` by encoding the schema first.
ipc_gen.schema_to_bytes_with_dictionary_tracker(
batch.schema().as_ref(),
&mut dict_tracker,
&write_options,
);
let mut compression_context = CompressionContext::default();
let (encoded_dictionaries, encoded_message) = ipc_gen
.encode(
@@ -1034,7 +1041,7 @@ fn encode_scalar_nested_value(
&mut compression_context,
)
.map_err(|e| {
Error::General(format!("Error encoding ScalarValue::List as IPC: {e}"))
Error::General(format!("Error encoding nested ScalarValue as IPC: {e}"))
})?;

let schema: protobuf::Schema = batch.schema().try_into()?;
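
Note: on the to_proto side, the essential ordering is that the `DictionaryTracker` must be populated by encoding the schema before the batch itself is encoded; otherwise dictionary-encoded columns have no IDs to attach their dictionary batches to. A rough sketch of that ordering follows, under two assumptions: it uses the older `encoded_batch` API rather than the `encode` + `CompressionContext` variant the PR uses, and `encode_single_column_batch` is a hypothetical helper, not part of DataFusion.

use std::sync::Arc;
use arrow::array::{ArrayRef, DictionaryArray, StringArray, UInt32Array};
use arrow::ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
use arrow::record_batch::RecordBatch;

fn encode_single_column_batch(arr: ArrayRef) -> arrow::error::Result<()> {
    // Nested scalars are shipped as a single-column RecordBatch.
    let batch = RecordBatch::try_from_iter(vec![("field_name", arr)])?;

    let ipc_gen = IpcDataGenerator {};
    let mut dict_tracker = DictionaryTracker::new(false);
    let write_options = IpcWriteOptions::default();

    // Seed dictionary IDs by encoding the schema first; the returned bytes are
    // not needed here, only the tracker's side effect is.
    ipc_gen.schema_to_bytes_with_dictionary_tracker(
        batch.schema().as_ref(),
        &mut dict_tracker,
        &write_options,
    );

    // Now the batch and its dictionaries can be turned into IPC messages.
    let (dictionaries, message) =
        ipc_gen.encoded_batch(&batch, &mut dict_tracker, &write_options)?;
    println!(
        "{} dictionary message(s), {} data bytes",
        dictionaries.len(),
        message.arrow_data.len()
    );
    Ok(())
}

fn main() -> arrow::error::Result<()> {
    let keys = UInt32Array::from(vec![0, 0, 1]);
    let values = StringArray::from(vec!["a", "b"]);
    let dict = DictionaryArray::try_new(keys, Arc::new(values))?;
    encode_single_column_batch(Arc::new(dict) as ArrayRef)
}
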
19 changes: 19 additions & 0 deletions datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -2564,3 +2564,22 @@ fn custom_proto_converter_intercepts() -> Result<()> {

Ok(())
}

#[test]
fn roundtrip_call_null_scalar_struct_dict() -> Result<()> {
let data_type = DataType::Struct(Fields::from(vec![Field::new(
"item",
DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
true,
)]));

let schema = Arc::new(Schema::new(vec![Field::new("a", data_type.clone(), true)]));
let scan = Arc::new(EmptyExec::new(Arc::clone(&schema)));
let scalar = lit(ScalarValue::try_from(data_type)?);
let filter = Arc::new(FilterExec::try_new(
Arc::new(BinaryExpr::new(scalar, Operator::Eq, col("a", &schema)?)),
scan,
)?);

roundtrip_test(filter)
}