Skip to content

Commit c9ce954

Browse files
committed
feat: add --check-increase option
1 parent 2e5391d commit c9ce954

File tree

6 files changed

+217
-7
lines changed

6 files changed

+217
-7
lines changed

news/sort-squeezed-x.rst

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
**Added:**
2+
3+
* Add ``--check-increase`` option for squeeze morph.
4+
5+
**Changed:**
6+
7+
* <news item>
8+
9+
**Deprecated:**
10+
11+
* <news item>
12+
13+
**Removed:**
14+
15+
* <news item>
16+
17+
**Fixed:**
18+
19+
* <news item>
20+
21+
**Security:**
22+
23+
* <news item>

src/diffpy/morph/morph_io.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ def tabulate_results(multiple_morph_results):
408408
return tabulated_results
409409

410410

411-
def handle_warnings(squeeze_morph):
411+
def handle_extrapolation_warnings(squeeze_morph):
412412
if squeeze_morph is not None:
413413
extrapolation_info = squeeze_morph.extrapolation_info
414414
is_extrap_low = extrapolation_info["is_extrap_low"]
@@ -443,3 +443,25 @@ def handle_warnings(squeeze_morph):
443443
wmsg,
444444
UserWarning,
445445
)
446+
447+
448+
def handle_check_increase_warning(squeeze_morph):
449+
if squeeze_morph is not None:
450+
if squeeze_morph.squeeze_info["monotonic"]:
451+
wmsg = None
452+
else:
453+
overlapping_regions = squeeze_morph.squeeze_info[
454+
"overlapping_regions"
455+
]
456+
wmsg = (
457+
"Warning: The squeeze morph has interpolated your morphed "
458+
"function from a non-monotonically increasing grid. "
459+
"This can result in strange behavior in the regions "
460+
f"{overlapping_regions}. To disable this setting, "
461+
"please enable --check-increasing."
462+
)
463+
if wmsg:
464+
warnings.warn(
465+
wmsg,
466+
UserWarning,
467+
)

src/diffpy/morph/morphapp.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,16 @@ def custom_error(self, msg):
207207
"See online documentation for more information."
208208
),
209209
)
210+
group.add_option(
211+
"--check-increase",
212+
action="store_true",
213+
dest="check_increase",
214+
help=(
215+
"Disable squeeze morph to interpolat morphed function "
216+
"from a non-monotonically increasing grid."
217+
),
218+
)
219+
210220
group.add_option(
211221
"--smear",
212222
type="float",
@@ -571,7 +581,7 @@ def single_morph(
571581
except ValueError:
572582
parser.error(f"{coeff} could not be converted to float.")
573583
squeeze_poly_deg = len(squeeze_dict_in.keys())
574-
squeeze_morph = morphs.MorphSqueeze()
584+
squeeze_morph = morphs.MorphSqueeze(check_increase=opts.check_increase)
575585
chain.append(squeeze_morph)
576586
config["squeeze"] = squeeze_dict_in
577587
# config["extrap_index_low"] = None
@@ -701,8 +711,9 @@ def single_morph(
701711
chain(x_morph, y_morph, x_target, y_target)
702712

703713
# THROW ANY WARNINGS HERE
704-
io.handle_warnings(squeeze_morph)
705-
io.handle_warnings(shift_morph)
714+
io.handle_extrapolation_warnings(squeeze_morph)
715+
io.handle_check_increase_warning(squeeze_morph)
716+
io.handle_extrapolation_warnings(shift_morph)
706717

707718
# Get Rw for the morph range
708719
rw = tools.getRw(chain)

src/diffpy/morph/morphpy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __get_morph_opts__(parser, scale, stretch, smear, plot, **kwargs):
5151
"reverse",
5252
"diff",
5353
"get-diff",
54+
"check-increase",
5455
]
5556
opts_to_ignore = ["multiple-morphs", "multiple-targets"]
5657
for opt in opts_storing_values:

src/diffpy/morph/morphs/morphsqueeze.py

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Class MorphSqueeze -- Apply a polynomial to squeeze the morph
22
function."""
33

4+
import numpy
45
from numpy.polynomial import Polynomial
56
from scipy.interpolate import CubicSpline
67

@@ -68,8 +69,68 @@ class MorphSqueeze(Morph):
6869
squeeze_cutoff_low = None
6970
squeeze_cutoff_high = None
7071

71-
def __init__(self, config=None):
72+
def __init__(self, config=None, check_increase=False):
7273
super().__init__(config)
74+
self.check_increase = check_increase
75+
76+
def _set_squeeze_info(self, x, x_sorted):
77+
self.squeeze_info = {"monotonic": True, "overlapping_regions": None}
78+
if list(x) != list(x_sorted):
79+
if self.check_increase:
80+
raise ValueError(
81+
"Squeezed grid is not strictly increasing."
82+
"Please (1) decrease the order of your polynomial and "
83+
"(2) ensure that the initial polynomial morph result in "
84+
"good agreement between your reference and "
85+
"objective functions."
86+
)
87+
else:
88+
overlapping_regions = self._get_overlapping_regions(x)
89+
self.squeeze_info["monotonic"] = False
90+
self.squeeze_info["overlapping_regions"] = overlapping_regions
91+
92+
def _sort_squeeze(self, x, y):
93+
"""Sort x,y according to the value of x."""
94+
xy = list(zip(x, y))
95+
xy_sorted = sorted(xy, key=lambda pair: pair[0])
96+
x_sorted, y_sorted = list(zip(*xy_sorted))
97+
return x_sorted, y_sorted
98+
99+
def _get_overlapping_regions(self, x):
100+
diffx = numpy.diff(x)
101+
monotomic_regions = []
102+
monotomic_signs = [numpy.sign(diffx[0])]
103+
current_region = [x[0], x[1]]
104+
for i in range(1, len(diffx)):
105+
if numpy.sign(diffx[i]) == monotomic_signs[-1]:
106+
current_region.append(x[i + 1])
107+
else:
108+
monotomic_regions.append(current_region)
109+
monotomic_signs.append(diffx[i])
110+
current_region = [x[i + 1]]
111+
monotomic_regions.append(current_region)
112+
overlapping_regions_sign = -1 if x[0] < x[-1] else 1
113+
overlapping_regions_x = [
114+
monotomic_regions[i]
115+
for i in range(len(monotomic_regions))
116+
if monotomic_signs[i] == overlapping_regions_sign
117+
]
118+
overlapping_regions = [
119+
(min(region), max(region)) for region in overlapping_regions_x
120+
]
121+
return overlapping_regions
122+
123+
def _handle_duplicates(self, x, y):
124+
"""Remove duplicated x and use the mean value of y corresponded
125+
to the duplicated x."""
126+
unq_x, unq_inv = numpy.unique(x, return_inverse=True)
127+
if len(unq_x) == len(x):
128+
return x, y
129+
else:
130+
y_avg = numpy.zeros_like(unq_x)
131+
for i in range(len(unq_x)):
132+
y_avg[i] = numpy.array(y)[unq_inv == i].mean()
133+
return unq_x, y_avg
73134

74135
def morph(self, x_morph, y_morph, x_target, y_target):
75136
"""Apply a polynomial to squeeze the morph function.
@@ -82,9 +143,16 @@ def morph(self, x_morph, y_morph, x_target, y_target):
82143
coeffs = [self.squeeze[f"a{i}"] for i in range(len(self.squeeze))]
83144
squeeze_polynomial = Polynomial(coeffs)
84145
x_squeezed = self.x_morph_in + squeeze_polynomial(self.x_morph_in)
85-
self.y_morph_out = CubicSpline(x_squeezed, self.y_morph_in)(
146+
x_squeezed_sorted, y_morph_sorted = self._sort_squeeze(
147+
x_squeezed, self.y_morph_in
148+
)
149+
self._set_squeeze_info(x_squeezed_sorted, x_squeezed)
150+
x_squeezed_sorted, y_morph_sorted = self._handle_duplicates(
151+
x_squeezed_sorted, y_morph_sorted
152+
)
153+
self.y_morph_out = CubicSpline(x_squeezed_sorted, y_morph_sorted)(
86154
self.x_morph_in
87155
)
88-
self.set_extrapolation_info(x_squeezed, self.x_morph_in)
156+
self.set_extrapolation_info(x_squeezed_sorted, self.x_morph_in)
89157

90158
return self.xyallout

tests/test_morphsqueeze.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,91 @@ def test_morphsqueeze_extrapolate(
173173
single_morph(parser, opts, pargs, stdout_flag=False)
174174

175175

176+
@pytest.mark.parametrize(
177+
"squeeze_coeffs, x_morph",
178+
[
179+
({"a0": -1, "a1": -1, "a2": 2}, np.linspace(-1, 1, 101)),
180+
],
181+
)
182+
def test_sort_squeeze_bad(user_filesystem, squeeze_coeffs, x_morph):
183+
# call in .py without --check-increase
184+
x_target = x_morph
185+
y_target = np.sin(x_target)
186+
coeffs = [squeeze_coeffs[f"a{i}"] for i in range(len(squeeze_coeffs))]
187+
squeeze_polynomial = Polynomial(coeffs)
188+
x_squeezed = x_morph + squeeze_polynomial(x_morph)
189+
y_morph = np.sin(x_squeezed)
190+
morph = MorphSqueeze()
191+
morph.squeeze = squeeze_coeffs
192+
with pytest.warns() as w:
193+
morphpy.morph_arrays(
194+
np.array([x_morph, y_morph]).T,
195+
np.array([x_target, y_target]).T,
196+
squeeze=coeffs,
197+
apply=True,
198+
)
199+
assert len(w) == 1
200+
assert w[0].category is UserWarning
201+
actual_wmsg = str(w[0].message)
202+
expected_wmsg = (
203+
"Warning: The squeeze morph has interpolated your morphed "
204+
"function from a non-monotonically increasing grid. "
205+
)
206+
assert expected_wmsg in actual_wmsg
207+
208+
# call in .py with --check-increase
209+
with pytest.raises(ValueError) as excinfo:
210+
morphpy.morph_arrays(
211+
np.array([x_morph, y_morph]).T,
212+
np.array([x_target, y_target]).T,
213+
squeeze=coeffs,
214+
check_increase=True,
215+
apply=True,
216+
)
217+
actual_emsg = str(excinfo.value)
218+
expected_emsg = "Squeezed grid is not strictly increasing."
219+
assert expected_emsg in actual_emsg
220+
221+
# call in CLI without --check-increase
222+
morph_file, target_file = create_morph_data_file(
223+
user_filesystem / "cwd_dir", x_morph, y_morph, x_target, y_target
224+
)
225+
parser = create_option_parser()
226+
(opts, pargs) = parser.parse_args(
227+
[
228+
"--squeeze",
229+
",".join(map(str, coeffs)),
230+
f"{morph_file.as_posix()}",
231+
f"{target_file.as_posix()}",
232+
"--apply",
233+
"-n",
234+
]
235+
)
236+
with pytest.warns(UserWarning) as w:
237+
single_morph(parser, opts, pargs, stdout_flag=False)
238+
assert len(w) == 1
239+
actual_wmsg = str(w[0].message)
240+
assert expected_wmsg in actual_wmsg
241+
242+
# call in CLI with --check-increase
243+
parser = create_option_parser()
244+
(opts, pargs) = parser.parse_args(
245+
[
246+
"--squeeze",
247+
",".join(map(str, coeffs)),
248+
f"{morph_file.as_posix()}",
249+
f"{target_file.as_posix()}",
250+
"--apply",
251+
"-n",
252+
"--check-increase",
253+
]
254+
)
255+
with pytest.raises(ValueError) as excinfo:
256+
single_morph(parser, opts, pargs, stdout_flag=False)
257+
actual_emsg = str(excinfo.value)
258+
assert expected_emsg in actual_emsg
259+
260+
176261
def create_morph_data_file(
177262
data_dir_path, x_morph, y_morph, x_target, y_target
178263
):

0 commit comments

Comments
 (0)