Skip to content

Commit 7ec122c

Browse files
committed
poc: fix mangled errors
1 parent c609dfa commit 7ec122c

File tree

5 files changed

+56
-6
lines changed

5 files changed

+56
-6
lines changed

python/tests/test_catalog.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,12 @@ def table_exist(self, name: str) -> bool:
7676
return name in self.tables
7777

7878

79+
class CustomErrorSchemaProvider(CustomSchemaProvider):
80+
def table(self, name: str) -> Table | None:
81+
message = f"{name} is not an acceptable name"
82+
raise ValueError(message)
83+
84+
7985
class CustomCatalogProvider(dfn.catalog.CatalogProvider):
8086
def __init__(self):
8187
self.schemas = {"my_schema": CustomSchemaProvider()}
@@ -164,6 +170,33 @@ def test_python_table_provider(ctx: SessionContext):
164170
assert schema.table_names() == {"table4"}
165171

166172

173+
def test_exception_not_mangled(ctx: SessionContext):
174+
"""Test registering all python providers and running a query against them."""
175+
176+
catalog_name = "custom_catalog"
177+
schema_name = "custom_schema"
178+
179+
ctx.register_catalog_provider(catalog_name, CustomCatalogProvider())
180+
181+
catalog = ctx.catalog(catalog_name)
182+
183+
# Clean out previous schemas if they exist so we can start clean
184+
for schema_name in catalog.schema_names():
185+
catalog.deregister_schema(schema_name, cascade=False)
186+
187+
catalog.register_schema(schema_name, CustomErrorSchemaProvider())
188+
189+
schema = catalog.schema(schema_name)
190+
191+
for table_name in schema.table_names():
192+
schema.deregister_table(table_name)
193+
194+
schema.register_table("test_table", create_dataset())
195+
196+
with pytest.raises(ValueError, match="^test_table is not an acceptable name$"):
197+
ctx.sql(f"select * from {catalog_name}.{schema_name}.test_table")
198+
199+
167200
def test_in_end_to_end_python_providers(ctx: SessionContext):
168201
"""Test registering all python providers and running a query against them."""
169202

python/tests/test_sql.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@
2929

3030

3131
def test_no_table(ctx):
32-
with pytest.raises(Exception, match="DataFusion error"):
32+
with pytest.raises(
33+
ValueError,
34+
match="^Error during planning: table 'datafusion.public.b' not found$",
35+
):
3336
ctx.sql("SELECT a FROM b").collect()
3437

3538

src/catalog.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ impl RustWrappedPySchemaProvider {
323323
match py_table.extract::<PyTable>() {
324324
Ok(py_table) => Ok(Some(py_table.table)),
325325
Err(_) => {
326-
let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?;
326+
let ds = Dataset::new(&py_table, py)?;
327327
Ok(Some(Arc::new(ds) as Arc<dyn TableProvider>))
328328
}
329329
}
@@ -360,7 +360,8 @@ impl SchemaProvider for RustWrappedPySchemaProvider {
360360
&self,
361361
name: &str,
362362
) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
363-
self.table_inner(name).map_err(to_datafusion_err)
363+
self.table_inner(name)
364+
.map_err(|e| DataFusionError::External(Box::new(e)))
364365
}
365366

366367
fn register_table(

src/context.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ use pyo3::prelude::*;
3434
use crate::catalog::{PyCatalog, PyTable, RustWrappedPyCatalogProvider};
3535
use crate::dataframe::PyDataFrame;
3636
use crate::dataset::Dataset;
37-
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
37+
use crate::errors::{
38+
from_datafusion_error, py_datafusion_err, to_datafusion_err, PyDataFusionResult,
39+
};
3840
use crate::expr::sort_expr::PySortExpr;
3941
use crate::physical_plan::PyExecutionPlan;
4042
use crate::record_batch::PyRecordBatchStream;
@@ -433,9 +435,9 @@ impl PySessionContext {
433435
}
434436

435437
/// Returns a PyDataFrame whose plan corresponds to the SQL statement.
436-
pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
438+
pub fn sql(&mut self, query: &str, py: Python) -> PyResult<PyDataFrame> {
437439
let result = self.ctx.sql(query);
438-
let df = wait_for_future(py, result)??;
440+
let df = wait_for_future(py, result)?.map_err(from_datafusion_error)?;
439441
Ok(PyDataFrame::new(df))
440442
}
441443

src/errors.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use std::fmt::Debug;
2222
use datafusion::arrow::error::ArrowError;
2323
use datafusion::error::DataFusionError as InnerDataFusionError;
2424
use prost::EncodeError;
25+
use pyo3::exceptions::PyValueError;
2526
use pyo3::{exceptions::PyException, PyErr};
2627

2728
pub type PyDataFusionResult<T> = std::result::Result<T, PyDataFusionError>;
@@ -95,3 +96,13 @@ pub fn py_unsupported_variant_err(e: impl Debug) -> PyErr {
9596
pub fn to_datafusion_err(e: impl Debug) -> InnerDataFusionError {
9697
InnerDataFusionError::Execution(format!("{e:?}"))
9798
}
99+
100+
pub fn from_datafusion_error(err: InnerDataFusionError) -> PyErr {
101+
match err {
102+
InnerDataFusionError::External(boxed) => match boxed.downcast::<PyErr>() {
103+
Ok(py_err) => *py_err,
104+
Err(original_boxed) => PyValueError::new_err(format!("{original_boxed}")),
105+
},
106+
_ => PyValueError::new_err(format!("{err}")),
107+
}
108+
}

0 commit comments

Comments
 (0)