Skip to content

Commit c318c0b

Browse files
committed
add to_2d_simple func
1 parent 29e7839 commit c318c0b

File tree

1 file changed

+42
-1
lines changed

1 file changed

+42
-1
lines changed

src/ndarrayext.rs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,30 @@ pub fn to_2d<'a, T: 'a, D, I>(data: I, axis: Axis) -> Result<ArrayView2<'a, T>>
6969
format!("Cannot reshape {}-d array with shape {:?} by axis {} \
7070
to 2-d array with shape {:?}. Error: {}",
7171
ndim, shape, axis.0, new_shape, error)
72-
))
72+
)
73+
)
74+
}
75+
}
76+
77+
78+
pub fn to_2d_simple<'a, T: 'a, D>(data: ArrayView<'a, T, D>) -> Result<ArrayView2<'a, T>>
79+
where
80+
D: Dimension
81+
{
82+
let ndim = data.ndim();
83+
let shape = data.shape().to_vec();
84+
let new_shape = [shape[0..(ndim - 1)].iter().product(), shape[ndim - 1]];
85+
86+
match data.into_shape(new_shape) {
87+
Ok(data_2d) => Ok(data_2d),
88+
Err(err) => {
89+
return Err(
90+
ReshapeError(
91+
format!("Cannot reshape data array with shape {:?} to 2-d array with \
92+
shape {:?}. Error: {}", shape, new_shape, err)
93+
)
94+
)
95+
}
7396
}
7497
}
7598

@@ -231,6 +254,24 @@ mod tests {
231254
assert_eq!(to_2d(&a, Axis(2)).unwrap(), array![[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]);
232255
}
233256

257+
#[test]
258+
fn test_to_2d_simple_from_1d() {
259+
let a = array![1, 2, 3, 4];
260+
assert_eq!(to_2d_simple(a.view()).unwrap(), array![[1, 2, 3, 4]]);
261+
}
262+
263+
#[test]
264+
fn test_to_2d_simple_from_2d() {
265+
let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];
266+
assert_eq!(to_2d_simple(a.view()).unwrap(), array![[1, 2, 3, 4], [5, 6, 7, 8]]);
267+
}
268+
269+
#[test]
270+
fn test_to_2d_simple_from_3d() {
271+
let a = array![[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]];
272+
assert_eq!(to_2d_simple(a.view()).unwrap(), array![[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]);
273+
}
274+
234275
#[test]
235276
fn test_from_2d_to_3d() {
236277
let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]];

0 commit comments

Comments
 (0)