Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 79 additions & 38 deletions Modules/_base64/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::cell::UnsafeCell;
use std::ffi::{c_char, c_int, c_void};
use std::ffi::{CStr, c_char, c_int, c_void};
use std::mem::MaybeUninit;
use std::ptr;
use std::slice;
Expand All @@ -12,6 +12,8 @@ use cpython_sys::PyBuffer_Release;
use cpython_sys::PyBytes_AsString;
use cpython_sys::PyBytes_FromStringAndSize;
use cpython_sys::PyErr_NoMemory;
use cpython_sys::PyErr_SetNone;
use cpython_sys::PyErr_SetObject;
use cpython_sys::PyErr_SetString;
use cpython_sys::PyExc_TypeError;
use cpython_sys::PyMethodDef;
Expand All @@ -22,6 +24,59 @@ use cpython_sys::PyModuleDef_Init;
use cpython_sys::PyObject;
use cpython_sys::PyObject_GetBuffer;

// Error Handling Abstraction

/// Zero-sized type indicating that a Python exception has been set.
/// Using this type will ensure `Result<&PyObject, ExecutedErr>` and `Result<PyRc, ExecutedErr>`
/// to be same size as `*mut PyObject`.
#[derive(Debug, Clone, Copy)]
pub struct ExecutedErr;

/// Enum representing different ways to set a Python exception.
///
/// This type is NOT stored in Result - it's immediately converted to
/// `ExecutedErr` via `.into()`, which triggers the actual C API call.
pub enum MakeErr {
SetString(*mut PyObject, *const c_char),
SetObject(*mut PyObject, *mut PyObject),
SetNone(*mut PyObject),
NoMemory,
}

impl MakeErr {
fn execute(self) -> ExecutedErr {
match self {
MakeErr::SetString(exc_type, msg) => {
unsafe { PyErr_SetString(exc_type, msg) };
}
MakeErr::SetObject(exc_type, value) => {
unsafe { PyErr_SetObject(exc_type, value) };
}
MakeErr::SetNone(exc_type) => {
unsafe { PyErr_SetNone(exc_type) };
}
MakeErr::NoMemory => {
unsafe { PyErr_NoMemory() };
}
}
ExecutedErr
}

#[inline]
pub fn type_error(msg: &CStr) -> Self {
Self::SetString(unsafe { PyExc_TypeError }, msg.as_ptr())
}
}

impl From<MakeErr> for ExecutedErr {
#[inline]
fn from(exc: MakeErr) -> Self {
exc.execute()
}
}

pub type PyResult<T> = Result<T, ExecutedErr>;

const PYBUF_SIMPLE: c_int = 0;
const PAD_BYTE: u8 = b'=';
const ENCODE_TABLE: [u8; 64] = *b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
Expand Down Expand Up @@ -82,11 +137,12 @@ struct BorrowedBuffer {
}

impl BorrowedBuffer {
fn from_object(obj: &PyObject) -> Result<Self, ()> {
fn from_object(obj: &PyObject) -> PyResult<Self> {
let mut view = MaybeUninit::<Py_buffer>::uninit();
let buffer = unsafe {
if PyObject_GetBuffer(obj.as_raw(), view.as_mut_ptr(), PYBUF_SIMPLE) != 0 {
return Err(());
// PyObject_GetBuffer already set the exception
return Err(ExecutedErr);
}
Self {
view: view.assume_init(),
Expand All @@ -102,6 +158,15 @@ impl BorrowedBuffer {
fn as_ptr(&self) -> *const u8 {
self.view.buf.cast::<u8>() as *const u8
}

fn as_slice(&self, msg: &CStr) -> PyResult<&[u8]> {
let len = self.len();
if len < 0 {
return Err(MakeErr::type_error(msg).into());
}
let slice = unsafe { slice::from_raw_parts(self.as_ptr(), len as usize) };
Ok(slice)
}
}

impl Drop for BorrowedBuffer {
Expand All @@ -122,12 +187,7 @@ pub unsafe extern "C" fn standard_b64encode(
nargs: Py_ssize_t,
) -> *mut PyObject {
if nargs != 1 {
unsafe {
PyErr_SetString(
PyExc_TypeError,
c"standard_b64encode() takes exactly one argument".as_ptr(),
);
}
MakeErr::type_error(c"standard_b64encode() takes exactly one argument").execute();
return ptr::null_mut();
}

Expand All @@ -140,51 +200,32 @@ pub unsafe extern "C" fn standard_b64encode(
}
}

fn standard_b64encode_impl(source: &PyObject) -> Result<*mut PyObject, ()> {
let buffer = match BorrowedBuffer::from_object(source) {
Ok(buf) => buf,
Err(_) => return Err(()),
};

let view_len = buffer.len();
if view_len < 0 {
unsafe {
PyErr_SetString(
PyExc_TypeError,
c"standard_b64encode() argument has negative length".as_ptr(),
);
}
return Err(());
}
fn standard_b64encode_impl(source: &PyObject) -> PyResult<*mut PyObject> {
let buffer = BorrowedBuffer::from_object(source)?;

let input_len = view_len as usize;
let input = unsafe { slice::from_raw_parts(buffer.as_ptr(), input_len) };
let input = buffer.as_slice(c"standard_b64encode() argument has negative length")?;

let Some(output_len) = encoded_output_len(input_len) else {
unsafe {
PyErr_NoMemory();
}
return Err(());
let Some(output_len) = encoded_output_len(input.len()) else {
return Err(MakeErr::NoMemory.into());
};

if output_len > isize::MAX as usize {
unsafe {
PyErr_NoMemory();
}
return Err(());
return Err(MakeErr::NoMemory.into());
}

let result = unsafe { PyBytes_FromStringAndSize(ptr::null(), output_len as Py_ssize_t) };
if result.is_null() {
return Err(());
// PyBytes_FromStringAndSize already set the exception
return Err(ExecutedErr);
}

let dest_ptr = unsafe { PyBytes_AsString(result) };
if dest_ptr.is_null() {
unsafe {
Py_DecRef(result);
}
return Err(());
// PyBytes_AsString already set the exception
return Err(ExecutedErr);
}
let dest = unsafe { slice::from_raw_parts_mut(dest_ptr.cast::<u8>(), output_len) };

Expand Down