Skip to content

Commit c1c8f4a

Browse files
committed
implement ndg evaluate
1 parent ed6e650 commit c1c8f4a

File tree

9 files changed

+138
-26
lines changed

9 files changed

+138
-26
lines changed

src/ndg.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
mod validate;
22
mod make;
33
mod evaluate;
4+
mod util;
45

56
use ndarray::{
67
NdFloat,
@@ -191,7 +192,10 @@ impl<'a, T, D> GridCubicSmoothingSpline<'a, T, D>
191192
///
192193
pub fn evaluate(&self, xi: &[ArrayView1<'a, T>]) -> Result<Array<T, D>> {
193194
let xi = xi.to_vec();
194-
self.evaluate_spline(&xi)
195+
self.evaluate_validate(&xi)?;
196+
let yi = self.evaluate_spline(&xi);
197+
198+
Ok(yi)
195199
}
196200

197201
/// Returns the ref to smoothing parameters vector or None

src/ndg/evaluate.rs

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use ndarray::{NdFloat, Dimension, Array, ArrayView1};
22
use almost::AlmostEqual;
33

4-
use crate::Result;
5-
use super::{NdGridSpline, GridCubicSmoothingSpline};
4+
use crate::{NdSpline, ndarrayext::to_2d_simple};
5+
use super::{NdGridSpline, GridCubicSmoothingSpline, util::permute_axes};
66

77

88
impl<'a, T, D> NdGridSpline<'a, T, D>
@@ -12,7 +12,42 @@ impl<'a, T, D> NdGridSpline<'a, T, D>
1212
{
1313
/// Implements evaluating the spline on the given mesh of Xi-sites
1414
pub(super) fn evaluate_spline(&self, xi: &[ArrayView1<'a, T>]) -> Array<T, D> {
15-
unimplemented!();
15+
let mut coeffs = self.coeffs.to_owned();
16+
let mut coeffs_shape = coeffs.shape().to_vec();
17+
18+
let ndim_m1 = self.ndim - 1;
19+
let permuted_axes = permute_axes::<D>(self.ndim);
20+
21+
for ax in (0..self.ndim).rev() {
22+
let xi_ax = xi[ax];
23+
24+
let coeffs_2d = {
25+
let coeffs_2d = to_2d_simple(coeffs.view()).unwrap();
26+
27+
NdSpline::evaluate_spline(
28+
self.order[ax],
29+
self.pieces[ax],
30+
self.breaks[ax],
31+
coeffs_2d,
32+
xi_ax,
33+
)
34+
};
35+
36+
let mut shape = D::zeros(self.ndim);
37+
shape[ndim_m1] = xi_ax.len();
38+
for i in 0..ndim_m1 {
39+
shape[i] = coeffs_shape[i];
40+
}
41+
42+
coeffs = coeffs_2d
43+
.into_shape(shape).unwrap()
44+
.permuted_axes(permuted_axes.clone())
45+
.to_owned();
46+
47+
coeffs_shape = coeffs.shape().to_vec();
48+
}
49+
50+
coeffs
1651
}
1752
}
1853

@@ -22,7 +57,7 @@ impl<'a, T, D> GridCubicSmoothingSpline<'a, T, D>
2257
T: NdFloat + AlmostEqual + Default,
2358
D: Dimension
2459
{
25-
pub(super) fn evaluate_spline(&self, xi: &[ArrayView1<'a, T>]) -> Result<Array<T, D>> {
26-
unimplemented!();
60+
pub(super) fn evaluate_spline(&self, xi: &[ArrayView1<'a, T>]) -> Array<T, D> {
61+
self.spline.as_ref().unwrap().evaluate_spline(&xi)
2762
}
2863
}

src/ndg/make.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use almost::AlmostEqual;
44
use crate::{Result, CubicSmoothingSpline};
55
use crate::ndarrayext::to_2d_simple;
66

7-
use super::{GridCubicSmoothingSpline, NdGridSpline};
7+
use super::{GridCubicSmoothingSpline, NdGridSpline, util::permute_axes};
88

99

1010
impl<'a, T, D> GridCubicSmoothingSpline<'a, T, D>
@@ -22,11 +22,7 @@ impl<'a, T, D> GridCubicSmoothingSpline<'a, T, D>
2222

2323
let mut smooth: Vec<Option<T>> = vec![None; ndim];
2424

25-
let mut permute_axes = D::zeros(ndim);
26-
permute_axes[0] = ndim_m1;
27-
for ax in 0..ndim_m1 {
28-
permute_axes[ax+1] = ax;
29-
}
25+
let permuted_axes = permute_axes::<D>(ndim);
3026

3127
for ax in (0..ndim).rev() {
3228
let x = breaks[ax].view();
@@ -53,7 +49,7 @@ impl<'a, T, D> GridCubicSmoothingSpline<'a, T, D>
5349

5450
spline.coeffs()
5551
.into_shape(new_shape).unwrap()
56-
.permuted_axes(permute_axes.clone())
52+
.permuted_axes(permuted_axes.clone())
5753
.to_owned()
5854
};
5955

src/ndg/util.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
use ndarray::Dimension;
2+
3+
4+
pub(super) fn permute_axes<D>(ndim: usize) -> D
5+
where
6+
D: Dimension
7+
{
8+
let mut permute_axes = D::zeros(ndim);
9+
10+
permute_axes[0] = ndim - 1;
11+
for ax in 0..(ndim - 1) {
12+
permute_axes[ax + 1] = ax;
13+
}
14+
15+
permute_axes
16+
}

src/ndg/validate.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,22 @@ impl<'a, T, D> GridCubicSmoothingSpline<'a, T, D>
2525

2626
Ok(())
2727
}
28+
29+
pub(super) fn evaluate_validate(&self, xi: &[ArrayView1<'a, T>]) -> Result<()> {
30+
let x_len = self.x.len();
31+
let xi_len = xi.len();
32+
33+
if xi_len != x_len {
34+
return Err(
35+
InvalidInputData(
36+
format!("The number of `xi` vectors ({}) is not equal to the number of dimensions ({})",
37+
xi_len, x_len)
38+
)
39+
)
40+
}
41+
42+
Ok(())
43+
}
2844
}
2945

3046

src/umv.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,13 @@ impl<'a, T> NdSpline<'a, T>
9494

9595
/// Evaluates the spline on the given data sites
9696
pub fn evaluate(&self, xi: ArrayView1<'a, T>) -> Array2<T> {
97-
self.evaluate_spline(xi)
97+
Self::evaluate_spline(
98+
self.order,
99+
self.pieces,
100+
self.breaks.view(),
101+
self.coeffs.view(),
102+
xi,
103+
)
98104
}
99105
}
100106

@@ -278,7 +284,7 @@ impl<'a, T, D> CubicSmoothingSpline<'a, T, D>
278284
/// - If reshaping Y data to 2-d view has failed
279285
///
280286
pub fn make(mut self) -> Result<Self> {
281-
self.make_validate_data()?;
287+
self.make_validate()?;
282288
self.make_spline()?;
283289
Ok(self)
284290
}
@@ -294,7 +300,7 @@ impl<'a, T, D> CubicSmoothingSpline<'a, T, D>
294300
where X: AsArray<'a, T>
295301
{
296302
let xi = xi.into();
297-
self.evaluate_validate_data(xi)?;
303+
self.evaluate_validate(xi)?;
298304

299305
let yi = self.evaluate_spline(xi)?;
300306
Ok(yi)

src/umv/evaluate.rs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use ndarray::{NdFloat, Dimension, Array, Array1, Array2, ArrayView1, Axis, s, stack};
1+
use ndarray::{NdFloat, Dimension, Array, Array1, Array2, ArrayView1, Axis, s, stack, ArrayView2};
22
use almost::AlmostEqual;
33

44
use crate::{Result, ndarrayext};
@@ -9,10 +9,16 @@ impl<'a, T> NdSpline<'a, T>
99
where T: NdFloat + AlmostEqual
1010
{
1111
/// Implements evaluating the spline on the given mesh of Xi-sites
12-
pub(super) fn evaluate_spline(&self, xi: ArrayView1<'a, T>) -> Array2<T> {
12+
pub(crate) fn evaluate_spline(
13+
order: usize,
14+
pieces: usize,
15+
breaks: ArrayView1<'_, T>,
16+
coeffs: ArrayView2<'_, T>,
17+
xi: ArrayView1<'a, T>) -> Array2<T>
18+
{
1319

1420
let edges = {
15-
let mesh = self.breaks.slice(s![1..-1]);
21+
let mesh = breaks.slice(s![1..-1]);
1622
let one = Array1::<T>::ones((1, ));
1723
let left_bound = &one * T::neg_infinity();
1824
let right_bound = &one * T::infinity();
@@ -24,7 +30,7 @@ impl<'a, T> NdSpline<'a, T>
2430

2531
// Go to local coordinates
2632
let xi = {
27-
let indexed_breaks = indices.mapv(|i| self.breaks[i]);
33+
let indexed_breaks = indices.mapv(|i| breaks[i]);
2834
&xi - &indexed_breaks
2935
};
3036

@@ -35,7 +41,7 @@ impl<'a, T> NdSpline<'a, T>
3541
let get_indexed_coeffs = |inds: &Array1<usize>| {
3642
// Returns Nx1 2-d array of coeffs by given index
3743
let coeffs_by_index = |&index| {
38-
self.coeffs.slice(s![.., index]).insert_axis(Axis(1))
44+
coeffs.slice(s![.., index]).insert_axis(Axis(1))
3945
};
4046

4147
// Get the M-sized vector of coeffs values Nx1 arrays
@@ -51,8 +57,8 @@ impl<'a, T> NdSpline<'a, T>
5157
// Vectorized computing the spline pieces (polynoms) on the given data sites
5258
let mut values = get_indexed_coeffs(&indices);
5359

54-
for _ in 1..self.order {
55-
indices += self.pieces;
60+
for _ in 1..order {
61+
indices += pieces;
5662
values = values * &xi + get_indexed_coeffs(&indices);
5763
}
5864

@@ -75,7 +81,7 @@ impl<'a, T, D> CubicSmoothingSpline<'a, T, D>
7581
shape[i] = s
7682
}
7783

78-
let yi_2d = self.spline.as_ref().unwrap().evaluate_spline(xi);
84+
let yi_2d = self.spline.as_ref().unwrap().evaluate(xi);
7985
let yi = ndarrayext::from_2d(&yi_2d, shape, axis)?.to_owned();
8086

8187
Ok(yi)

src/umv/validate.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::{
1919
impl<'a, T, D> CubicSmoothingSpline<'a, T, D>
2020
where T: NdFloat + Default + AlmostEqual, D: Dimension
2121
{
22-
pub(super) fn make_validate_data(&self) -> Result<()> {
22+
pub(super) fn make_validate(&self) -> Result<()> {
2323
let x_size = self.x.len();
2424

2525
if x_size < 2 {
@@ -80,7 +80,15 @@ impl<'a, T, D> CubicSmoothingSpline<'a, T, D>
8080
Ok(())
8181
}
8282

83-
pub(super) fn evaluate_validate_data(&self, xi: ArrayView1<'a, T>) -> Result<()> {
83+
pub(super) fn evaluate_validate(&self, xi: ArrayView1<'a, T>) -> Result<()> {
84+
if xi.len() < 1 {
85+
return Err(
86+
InvalidInputData(
87+
"The size of data vectors must be greater or equal to 1".to_string()
88+
)
89+
)
90+
}
91+
8492
if self.spline.is_none() {
8593
return Err(
8694
InvalidInputData(

tests/ndg_evaluate.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
use ndarray::array;
2+
use approx::assert_abs_diff_eq;
3+
4+
use csaps::GridCubicSmoothingSpline;
5+
6+
7+
#[test]
8+
fn test_make_surface() {
9+
let x0 = array![1., 2., 3.];
10+
let x1 = array![1., 2., 3., 4.];
11+
12+
let x = vec![x0.view(), x1.view()];
13+
14+
let y = array![
15+
[1., 2., 3., 4.],
16+
[5., 6., 7., 8.],
17+
[9., 10., 11., 12.],
18+
];
19+
20+
let yi = GridCubicSmoothingSpline::new(&x, &y)
21+
.make().unwrap()
22+
.evaluate(&x).unwrap();
23+
24+
assert_abs_diff_eq!(yi, y);
25+
}

0 commit comments

Comments
 (0)