Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions pytests/test_dag_cbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,19 @@ def test_recursion_limit_exceed_on_nested_maps() -> None:
libipld.decode_dag_cbor(dag_cbor)

assert 'in DAG-CBOR decoding' in str(exc_info.value)


def test_dab_cbor_decode_map_int_key() -> None:
dag_cbor = bytes.fromhex('a10000')
with pytest.raises(ValueError) as exc_info:
libipld.decode_dag_cbor(dag_cbor)

assert 'Map keys must be strings' in str(exc_info.value)


def test_dab_cbor_encode_map_int_key() -> None:
obj = {0: 'value'}
with pytest.raises(ValueError) as exc_info:
libipld.encode_dag_cbor(obj)

assert 'Map keys must be strings' in str(exc_info.value)
20 changes: 13 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,24 @@ fn map_key_cmp(a: &Vec<u8>, b: &Vec<u8>) -> std::cmp::Ordering {
}
}

fn sort_map_keys(keys: &Bound<PyList>, len: usize) -> Vec<(PyBackedStr, usize)> {
fn sort_map_keys(keys: &Bound<PyList>, len: usize) -> Result<Vec<(PyBackedStr, usize)>> {
// Returns key and index.
let mut keys_str = Vec::with_capacity(len);
for i in 0..len {
let item = keys.get_item(i).unwrap();
let key = item.downcast::<PyString>().unwrap().to_owned();
let backed_str = PyBackedStr::try_from(key).unwrap();
let item = keys.get_item(i)?;
let key = match item.downcast::<PyString>() {
Ok(k) => k.to_owned(),
Err(_) => return Err(anyhow!("Map keys must be strings")),
};
let backed_str = match PyBackedStr::try_from(key) {
Ok(bs) => bs,
Err(_) => return Err(anyhow!("Failed to convert PyString to PyBackedStr")),
};
keys_str.push((backed_str, i));
}

if keys_str.len() < 2 {
return keys_str;
return Ok(keys_str);
}

keys_str.sort_by(|a, b| {
Expand All @@ -78,7 +84,7 @@ fn sort_map_keys(keys: &Bound<PyList>, len: usize) -> Vec<(PyBackedStr, usize)>
}
});

keys_str
Ok(keys_str)
}

fn get_bytes_from_py_any<'py>(obj: &'py Bound<'py, PyAny>) -> PyResult<&'py [u8]> {
Expand Down Expand Up @@ -252,7 +258,7 @@ fn encode_dag_cbor_from_pyobject<'py, W: Write>(
Ok(())
} else if let Ok(map) = obj.downcast::<PyDict>() {
let len = map.len();
let keys = sort_map_keys(&map.keys(), len);
let keys = sort_map_keys(&map.keys(), len)?;
let values = map.values();

encode::write_u64(w, MajorKind::Map, len as u64)?;
Expand Down
Loading