Skip to content

Commit 4ea1701

Browse files
committed
add validate data for ndg make
1 parent 13a8496 commit 4ea1701

File tree

4 files changed

+128
-26
lines changed

4 files changed

+128
-26
lines changed

src/ndg.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ impl<'a, T, D> NdGridCubicSmoothingSpline<'a, T, D>
185185
/// - If the data or parameters are invalid
186186
///
187187
pub fn make(mut self) -> Result<Self> {
188+
self.make_validate()?;
188189
self.make_spline()?;
189190
Ok(self)
190191
}

src/ndg/validate.rs

Lines changed: 97 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,41 @@ use ndarray::{
88
use almost::AlmostEqual;
99

1010
use crate::{Result, CsapsError::InvalidInputData};
11-
use crate::validate::validate_data_sites;
11+
use crate::validate::{validate_data_sites, validate_smooth_value};
1212

13+
use super::NdGridCubicSmoothingSpline;
1314

14-
pub(super) fn validate_xy<'a, T, D>(x: &'a [ArrayView1<'a, T>], y: ArrayView<'a, T, D>) -> Result<()>
15-
where T: NdFloat + AlmostEqual, D: Dimension
15+
16+
impl<'a, T, D> NdGridCubicSmoothingSpline<'a, T, D>
17+
where
18+
T: NdFloat + AlmostEqual + Default,
19+
D: Dimension
20+
{
21+
pub(super) fn make_validate(&self) -> Result<()> {
22+
validate_xy(&self.x, self.y.view())?;
23+
24+
if let Some(weights) = self.weights.as_ref() {
25+
validate_weights(&self.x, weights)?;
26+
}
27+
28+
if let Some(smooth) = self.smooth.as_ref() {
29+
validate_smooth(smooth)?;
30+
}
31+
32+
Ok(())
33+
}
34+
}
35+
36+
37+
pub(super) fn validate_xy<T, D>(x: &[ArrayView1<'_, T>], y: ArrayView<'_, T, D>) -> Result<()>
38+
where
39+
T: NdFloat + AlmostEqual,
40+
D: Dimension
1641
{
1742
if x.len() != y.ndim() {
1843
return Err(
1944
InvalidInputData(
20-
format!("The number of `X` data sites ({}) is not equal to `Y` data dimensionality ({})",
45+
format!("The number of `x` data sites ({}) is not equal to `y` data dimensionality ({})",
2146
x.len(), y.ndim())
2247
)
2348
)
@@ -28,17 +53,82 @@ pub(super) fn validate_xy<'a, T, D>(x: &'a [ArrayView1<'a, T>], y: ArrayView<'a,
2853
.zip(y.shape().iter())
2954
.enumerate()
3055
{
56+
let xi_len = xi.len();
57+
58+
if xi_len < 2 {
59+
return Err(
60+
InvalidInputData(
61+
format!("The size of `x` site vectors must be greater or equal to 2, axis {}", ax)
62+
)
63+
)
64+
}
65+
3166
validate_data_sites(xi.view())?;
3267

33-
if xi.len() != ys {
68+
if xi_len != ys {
3469
return Err(
3570
InvalidInputData(
36-
format!("`X` data sites vector size ({}) is not equal to `Y` data size ({}) for axis {}",
37-
xi.len(), ys, ax)
71+
format!("`x` data sites vector size ({}) is not equal to `y` data size ({}) for axis {}",
72+
xi_len, ys, ax)
3873
)
3974
)
4075
}
4176
}
4277

4378
Ok(())
4479
}
80+
81+
82+
pub(super) fn validate_weights<T>(x: &[ArrayView1<'_, T>], w: &[Option<ArrayView1<'_, T>>]) -> Result<()>
83+
where
84+
T: NdFloat + AlmostEqual
85+
{
86+
let x_len = x.len();
87+
let w_len = w.len();
88+
89+
if w_len != x_len {
90+
return Err(
91+
InvalidInputData(
92+
format!("The number of `weights` vectors ({}) is not equal to the number of `x` vectors ({})",
93+
w_len, x_len)
94+
)
95+
)
96+
}
97+
98+
for (ax, (xi, wi)) in x.iter().zip(w.iter()).enumerate() {
99+
if let Some(wi_view) = wi {
100+
let xi_len = xi.len();
101+
let wi_len = wi_view.len();
102+
103+
if wi_len != xi_len {
104+
return Err(
105+
InvalidInputData(
106+
format!("`weights` vector size ({}) is not equal to `x` vector size ({}) for axis {}",
107+
wi_len, xi_len, ax)
108+
)
109+
)
110+
}
111+
}
112+
}
113+
114+
Ok(())
115+
}
116+
117+
118+
pub(super) fn validate_smooth<T>(smooth: &[Option<T>]) -> Result<()>
119+
where
120+
T: NdFloat
121+
{
122+
for (ax, s_opt) in smooth.iter().enumerate() {
123+
if let Some(s) = s_opt {
124+
match validate_smooth_value(*s) {
125+
Ok(res) => (),
126+
Err(err) => {
127+
return Err(InvalidInputData(format!("{} for axis {}", err, ax)))
128+
}
129+
};
130+
}
131+
}
132+
133+
Ok(())
134+
}

src/umv/validate.rs

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,26 @@ use crate::{
1212
CubicSmoothingSpline,
1313
CsapsError::InvalidInputData,
1414
Result,
15-
validate::validate_data_sites,
15+
validate::{validate_data_sites, validate_smooth_value},
1616
};
1717

1818

1919
impl<'a, T, D> CubicSmoothingSpline<'a, T, D>
2020
where T: NdFloat + Default + AlmostEqual, D: Dimension
2121
{
2222
pub(super) fn make_validate_data(&self) -> Result<()> {
23+
let x_size = self.x.len();
24+
25+
if x_size < 2 {
26+
return Err(
27+
InvalidInputData(
28+
"The size of data vectors must be greater or equal to 2".to_string()
29+
)
30+
)
31+
}
32+
33+
validate_data_sites(self.x)?;
34+
2335
if self.y.ndim() == 0 {
2436
return Err(
2537
InvalidInputData("`y` has zero dimensionality".to_string())
@@ -38,7 +50,6 @@ impl<'a, T, D> CubicSmoothingSpline<'a, T, D>
3850
)
3951
}
4052

41-
let x_size = self.x.len();
4253
let y_size = self.y.len_of(axis);
4354

4455
if x_size != y_size {
@@ -50,16 +61,6 @@ impl<'a, T, D> CubicSmoothingSpline<'a, T, D>
5061
)
5162
}
5263

53-
if x_size < 2 {
54-
return Err(
55-
InvalidInputData(
56-
"The size of data vectors must be greater or equal to 2".to_string()
57-
)
58-
)
59-
}
60-
61-
validate_data_sites(self.x)?;
62-
6364
if let Some(weights) = self.weights {
6465
let w_size = weights.len();
6566

@@ -73,13 +74,7 @@ impl<'a, T, D> CubicSmoothingSpline<'a, T, D>
7374
}
7475

7576
if let Some(smooth) = self.smooth {
76-
if smooth < T::zero() || smooth > T::one() {
77-
return Err(
78-
InvalidInputData(
79-
format!("`smooth` value must be in range 0..1, given {:?}", smooth)
80-
)
81-
)
82-
}
77+
validate_smooth_value(smooth)?;
8378
}
8479

8580
Ok(())

src/validate.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,19 @@ pub(crate) fn validate_data_sites<T>(x: ArrayView1<T>) -> Result<()>
2121

2222
Ok(())
2323
}
24+
25+
26+
pub(crate) fn validate_smooth_value<T>(smooth: T) -> Result<()>
27+
where
28+
T: NdFloat
29+
{
30+
if smooth < T::zero() || smooth > T::one() {
31+
return Err(
32+
InvalidInputData(
33+
format!("`smooth` value must be in range 0..1, given {:?}", smooth)
34+
)
35+
)
36+
}
37+
38+
Ok(())
39+
}

0 commit comments

Comments
 (0)