Skip to content

Commit 976dae1

Browse files
committed
implement ndg make (currently, it does not work properly)
1 parent 5ca08cd commit 976dae1

File tree

4 files changed

+83
-58
lines changed

4 files changed

+83
-58
lines changed

src/ndg.rs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,10 @@ pub struct GridCubicSmoothingSpline<'a, T, D>
111111
y: ArrayView<'a, T, D>,
112112

113113
/// The optional data weights
114-
weights: Option<Vec<Option<ArrayView1<'a, T>>>>,
114+
weights: Vec<Option<ArrayView1<'a, T>>>,
115115

116116
/// The optional smoothing parameter
117-
smooth: Option<Vec<Option<T>>>,
117+
smooth: Vec<Option<T>>,
118118

119119
/// `NdSpline` struct with computed spline
120120
spline: Option<NdGridSpline<'a, T, D>>
@@ -131,11 +131,13 @@ impl<'a, T, D> GridCubicSmoothingSpline<'a, T, D>
131131
where
132132
Y: AsArray<'a, T, D>
133133
{
134+
let ndim = x.len();
135+
134136
GridCubicSmoothingSpline {
135137
x: x.to_vec(),
136138
y: y.into(),
137-
weights: None,
138-
smooth: None,
139+
weights: vec![None; ndim],
140+
smooth: vec![None; ndim],
139141
spline: None,
140142
}
141143
}
@@ -146,7 +148,7 @@ impl<'a, T, D> GridCubicSmoothingSpline<'a, T, D>
146148
///
147149
pub fn with_weights(mut self, weights: &[Option<ArrayView1<'a, T>>]) -> Self {
148150
self.invalidate();
149-
self.weights = Some(weights.to_vec());
151+
self.weights = weights.to_vec();
150152
self
151153
}
152154

@@ -162,7 +164,7 @@ impl<'a, T, D> GridCubicSmoothingSpline<'a, T, D>
162164
///
163165
pub fn with_smooth(mut self, smooth: &[Option<T>]) -> Self {
164166
self.invalidate();
165-
self.smooth = Some(smooth.to_vec());
167+
self.smooth = smooth.to_vec();
166168
self
167169
}
168170

@@ -193,8 +195,8 @@ impl<'a, T, D> GridCubicSmoothingSpline<'a, T, D>
193195
}
194196

195197
/// Returns the ref to smoothing parameters vector or None
196-
pub fn smooth(&self) -> Option<&Vec<Option<T>>> {
197-
self.smooth.as_ref()
198+
pub fn smooth(&self) -> &Vec<Option<T>> {
199+
&self.smooth
198200
}
199201

200202
/// Returns ref to `NdGridSpline` struct with data of computed spline or None

src/ndg/make.rs

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
use ndarray::{NdFloat, Dimension};
22
use almost::AlmostEqual;
33

4-
use crate::{Result, CsapsError::ReshapeError, CubicSmoothingSpline};
4+
use crate::{Result, CubicSmoothingSpline};
5+
use crate::ndarrayext::to_2d_simple;
6+
57
use super::{GridCubicSmoothingSpline, NdGridSpline};
68

79

@@ -15,10 +17,10 @@ impl<'a, T, D> GridCubicSmoothingSpline<'a, T, D>
1517
let ndim_m1 = ndim - 1;
1618

1719
let breaks = self.x.to_vec();
18-
let mut coeffs = self.y.view();
20+
let mut coeffs = self.y.to_owned();
1921
let mut coeffs_shape = coeffs.shape().to_vec();
2022

21-
let mut smooth: Vec<Option<T>> = Vec::new();
23+
let mut smooth: Vec<Option<T>> = vec![None; ndim];
2224

2325
let mut permute_axes = D::zeros(ndim);
2426
permute_axes[0] = ndim_m1;
@@ -28,41 +30,47 @@ impl<'a, T, D> GridCubicSmoothingSpline<'a, T, D>
2830

2931
for ax in (0..ndim).rev() {
3032
let x = breaks[ax].view();
33+
let y = to_2d_simple(coeffs.view())?;
34+
35+
let weights = self.weights[ax].map(|v| v.reborrow());
36+
let s = self.smooth[ax];
37+
38+
println!("\nx shape: {:?}", x.shape());
39+
println!("y shape: {:?}", y.shape());
40+
41+
println!("\nx: {:?}", x);
42+
println!("y: {:?}", y);
43+
44+
let sp = CubicSmoothingSpline::new(x, y)
45+
.with_optional_weights(weights)
46+
.with_optional_smooth(s)
47+
.make()?;
48+
49+
smooth[ax] = sp.smooth();
50+
51+
coeffs = {
52+
let spline = sp.spline().unwrap();
53+
54+
coeffs_shape[ndim_m1] = spline.pieces() * spline.order();
55+
let mut new_shape = D::zeros(ndim);
56+
for (ax, &sz) in coeffs_shape.iter().enumerate() {
57+
new_shape[ax] = sz
58+
}
59+
60+
spline.coeffs()
61+
.into_shape(new_shape).unwrap()
62+
.permuted_axes(permute_axes.clone())
63+
.to_owned()
64+
};
65+
66+
coeffs_shape = coeffs.shape().to_vec();
3167

32-
if ndim > 2 {
33-
let coeffs_2d = {
34-
let shape = coeffs.shape().to_vec();
35-
let new_shape = [shape[0..ndim_m1].iter().product(), shape[ndim_m1]];
36-
37-
match coeffs.view().into_shape(new_shape) {
38-
Ok(coeffs_2d) => coeffs_2d,
39-
Err(err) => {
40-
return Err(
41-
ReshapeError(
42-
format!("Cannot reshape data array with shape {:?} to 2-d array with shape {:?}\n{}",
43-
shape, new_shape, err)
44-
)
45-
)
46-
}
47-
}
48-
};
49-
50-
// CubicSmoothingSpline::new(x, coeffs_2d.view())
51-
// .make()?
52-
// .spline().unwrap()
53-
// .coeffs().to_owned()
54-
//
55-
// } else {
56-
//
57-
// CubicSmoothingSpline::new(x, coeffs.view())
58-
// .make()?
59-
// .spline().unwrap()
60-
// .coeffs().to_owned()
61-
}
68+
println!("\ncoeffs shape: {:?}", coeffs.shape());
69+
println!("coeffs: {:?}", coeffs);
6270
}
6371

64-
self.smooth = Some(smooth);
65-
self.spline = Some(NdGridSpline::new(breaks, coeffs.to_owned()));
72+
self.smooth = smooth;
73+
self.spline = Some(NdGridSpline::new(breaks, coeffs));
6674

6775
Ok(())
6876
}

src/ndg/validate.rs

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,8 @@ impl<'a, T, D> GridCubicSmoothingSpline<'a, T, D>
2020
{
2121
pub(super) fn make_validate(&self) -> Result<()> {
2222
validate_xy(&self.x, self.y.view())?;
23-
24-
if let Some(weights) = self.weights.as_ref() {
25-
validate_weights(&self.x, weights)?;
26-
}
27-
28-
if let Some(smooth) = self.smooth.as_ref() {
29-
validate_smooth(smooth)?;
30-
}
23+
validate_weights(&self.x, &self.weights)?;
24+
validate_smooth(&self.x, &self.smooth)?;
3125

3226
Ok(())
3327
}
@@ -89,7 +83,7 @@ pub(super) fn validate_weights<T>(x: &[ArrayView1<'_, T>], w: &[Option<ArrayView
8983
if w_len != x_len {
9084
return Err(
9185
InvalidInputData(
92-
format!("The number of `weights` vectors ({}) is not equal to the number of `x` vectors ({})",
86+
format!("The number of `weights` vectors ({}) is not equal to the number of dimensions ({})",
9387
w_len, x_len)
9488
)
9589
)
@@ -115,17 +109,26 @@ pub(super) fn validate_weights<T>(x: &[ArrayView1<'_, T>], w: &[Option<ArrayView
115109
}
116110

117111

118-
pub(super) fn validate_smooth<T>(smooth: &[Option<T>]) -> Result<()>
112+
pub(super) fn validate_smooth<T>(x: &[ArrayView1<'_, T>], smooth: &[Option<T>]) -> Result<()>
119113
where
120114
T: NdFloat
121115
{
116+
let x_len = x.len();
117+
let s_len = smooth.len();
118+
119+
if s_len != x_len {
120+
return Err(
121+
InvalidInputData(
122+
format!("The number of `smooth` values ({}) is not equal to the number of dimensions ({})",
123+
s_len, x_len)
124+
)
125+
)
126+
}
127+
122128
for (ax, s_opt) in smooth.iter().enumerate() {
123129
if let Some(s) = s_opt {
124-
match validate_smooth_value(*s) {
125-
Ok(res) => (),
126-
Err(err) => {
127-
return Err(InvalidInputData(format!("{} for axis {}", err, ax)))
128-
}
130+
if let Err(err) = validate_smooth_value(*s) {
131+
return Err(InvalidInputData(format!("{} for axis {}", err, ax)))
129132
};
130133
}
131134
}

src/umv.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,12 @@ impl<'a, T, D> CubicSmoothingSpline<'a, T, D>
237237
self
238238
}
239239

240+
pub(crate) fn with_optional_weights(mut self, weights: Option<ArrayView1<'a, T>>) -> Self {
241+
self.invalidate();
242+
self.weights = weights;
243+
self
244+
}
245+
240246
/// Sets the smoothing parameter
241247
///
242248
/// The smoothing parameter should be in range `[0, 1]`,
@@ -251,6 +257,12 @@ impl<'a, T, D> CubicSmoothingSpline<'a, T, D>
251257
self
252258
}
253259

260+
pub(crate) fn with_optional_smooth(mut self, smooth: Option<T>) -> Self {
261+
self.invalidate();
262+
self.smooth = smooth;
263+
self
264+
}
265+
254266
/// Makes (computes) the spline for given data and parameters
255267
///
256268
/// # Errors

0 commit comments

Comments
 (0)