Skip to content

Commit e35af9c

Browse files
committed
feat: add checkExtrapolation function in morph.Morph
1 parent dce54cf commit e35af9c

File tree

7 files changed

+76
-32
lines changed

7 files changed

+76
-32
lines changed

news/extrap-warnings.rst

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
**Added:**
2+
3+
* Enable ``diffpy.morph`` to detect extrapolation.
4+
5+
**Changed:**
6+
7+
* <news item>
8+
9+
**Deprecated:**
10+
11+
* <news item>
12+
13+
**Removed:**
14+
15+
* <news item>
16+
17+
**Fixed:**
18+
19+
* <news item>
20+
21+
**Security:**
22+
23+
* <news item>

src/diffpy/morph/morph_io.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -410,27 +410,30 @@ def tabulate_results(multiple_morph_results):
410410

411411
def handle_warnings(squeeze_morph):
412412
if squeeze_morph is not None:
413-
eil = squeeze_morph.extrap_index_low
414-
eih = squeeze_morph.extrap_index_high
415-
416-
if eil is not None or eih is not None:
417-
if eih is None:
413+
extrapolation_info = squeeze_morph.extrapolation_info
414+
is_extrap_low = extrapolation_info["is_extrap_low"]
415+
is_extrap_high = extrapolation_info["is_extrap_high"]
416+
cutoff_low = extrapolation_info["cutoff_low"]
417+
cutoff_high = extrapolation_info["cutoff_high"]
418+
419+
if is_extrap_low or is_extrap_high:
420+
if not is_extrap_high:
418421
wmsg = (
419422
"Warning: points with grid value below "
420-
f"{squeeze_morph.squeeze_cutoff_low} "
423+
f"{cutoff_low} "
421424
f"will be extrapolated."
422425
)
423-
elif eil is None:
426+
elif not is_extrap_low:
424427
wmsg = (
425428
"Warning: points with grid value above "
426-
f"{squeeze_morph.squeeze_cutoff_high} "
429+
f"{cutoff_high} "
427430
f"will be extrapolated."
428431
)
429432
else:
430433
wmsg = (
431434
"Warning: points with grid value below "
432-
f"{squeeze_morph.squeeze_cutoff_low} and above "
433-
f"{squeeze_morph.squeeze_cutoff_high} "
435+
f"{cutoff_low} and above "
436+
f"{cutoff_high} "
434437
f"will be extrapolated."
435438
)
436439
warnings.warn(

src/diffpy/morph/morphapp.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,10 +610,12 @@ def single_morph(
610610
config["smear"] = smear_in
611611
# Shift
612612
# Only enable hshift is squeeze is not enabled
613+
shift_morph = None
613614
if (
614615
opts.hshift is not None and squeeze_poly_deg < 0
615616
) or opts.vshift is not None:
616-
chain.append(morphs.MorphShift())
617+
shift_morph = morphs.MorphShift()
618+
chain.append(shift_morph)
617619
if opts.hshift is not None and squeeze_poly_deg < 0:
618620
hshift_in = opts.hshift
619621
config["hshift"] = hshift_in
@@ -700,6 +702,7 @@ def single_morph(
700702

701703
# THROW ANY WARNINGS HERE
702704
io.handle_warnings(squeeze_morph)
705+
io.handle_warnings(shift_morph)
703706

704707
# Get Rw for the morph range
705708
rw = tools.getRw(chain)

src/diffpy/morph/morphs/morph.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
# See LICENSE.txt for license information.
1313
#
1414
##############################################################################
15-
"""Morph -- base class for defining a morph.
16-
"""
15+
"""Morph -- base class for defining a morph."""
1716

1817

1918
LABEL_RA = "r (A)" # r-grid
@@ -246,6 +245,23 @@ def plotOutputs(self, xylabels=True, **plotargs):
246245
ylabel(self.youtlabel)
247246
return rv
248247

248+
def checkExtrapolation(self, x_true, x_extrapolate):
249+
import numpy
250+
251+
cutoff_low = min(x_true)
252+
cutoff_high = max(x_true)
253+
low_extrap = numpy.where(x_extrapolate < cutoff_low)[0]
254+
high_extrap = numpy.where(x_extrapolate > cutoff_high)[0]
255+
is_extrap_low = False if len(low_extrap) == 0 else True
256+
is_extrap_high = False if len(high_extrap) == 0 else True
257+
extrapolation_info = {
258+
"is_extrap_low": is_extrap_low,
259+
"cutoff_low": cutoff_low,
260+
"is_extrap_high": is_extrap_high,
261+
"cutoff_high": cutoff_high,
262+
}
263+
return extrapolation_info
264+
249265
def __getattr__(self, name):
250266
"""Obtain the value from self.config, when normal lookup fails.
251267

src/diffpy/morph/morphs/morphshift.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def morph(self, x_morph, y_morph, x_target, y_target):
5757
r = self.x_morph_in - hshift
5858
self.y_morph_out = numpy.interp(r, self.x_morph_in, self.y_morph_in)
5959
self.y_morph_out += vshift
60+
self.extrapolation_info = self.checkExtrapolation(self.x_morph_in, r)
6061
return self.xyallout
6162

6263

src/diffpy/morph/morphs/morphsqueeze.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Class MorphSqueeze -- Apply a polynomial to squeeze the morph
22
function."""
33

4-
import numpy as np
54
from numpy.polynomial import Polynomial
65
from scipy.interpolate import CubicSpline
76

@@ -83,14 +82,11 @@ def morph(self, x_morph, y_morph, x_target, y_target):
8382
coeffs = [self.squeeze[f"a{i}"] for i in range(len(self.squeeze))]
8483
squeeze_polynomial = Polynomial(coeffs)
8584
x_squeezed = self.x_morph_in + squeeze_polynomial(self.x_morph_in)
86-
self.squeeze_cutoff_low = min(x_squeezed)
87-
self.squeeze_cutoff_high = max(x_squeezed)
8885
self.y_morph_out = CubicSpline(x_squeezed, self.y_morph_in)(
8986
self.x_morph_in
9087
)
91-
low_extrap = np.where(self.x_morph_in < self.squeeze_cutoff_low)[0]
92-
high_extrap = np.where(self.x_morph_in > self.squeeze_cutoff_high)[0]
93-
self.extrap_index_low = low_extrap[-1] if low_extrap.size else None
94-
self.extrap_index_high = high_extrap[0] if high_extrap.size else None
88+
self.extrapolation_info = self.checkExtrapolation(
89+
x_squeezed, self.x_morph_in
90+
)
9591

9692
return self.xyallout

tests/test_morphsqueeze.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,23 +46,27 @@
4646
@pytest.mark.parametrize("squeeze_coeffs", squeeze_coeffs_dic)
4747
def test_morphsqueeze(x_morph, x_target, squeeze_coeffs):
4848
y_target = np.sin(x_target)
49+
y_morph = np.sin(x_morph)
50+
# expected output
4951
coeffs = [squeeze_coeffs[f"a{i}"] for i in range(len(squeeze_coeffs))]
5052
squeeze_polynomial = Polynomial(coeffs)
5153
x_squeezed = x_morph + squeeze_polynomial(x_morph)
52-
y_morph = np.sin(x_squeezed)
53-
low_extrap = np.where(x_morph < x_squeezed[0])[0]
54-
high_extrap = np.where(x_morph > x_squeezed[-1])[0]
55-
extrap_index_low_expected = low_extrap[-1] if low_extrap.size else None
56-
extrap_index_high_expected = high_extrap[0] if high_extrap.size else None
54+
y_morph_expected = y_morph
5755
x_morph_expected = x_morph
58-
y_morph_expected = np.sin(x_morph)
56+
x_target_expected = x_target
57+
y_target_expected = y_target
58+
# actual output
5959
morph = MorphSqueeze()
60+
y_morph = np.sin(x_squeezed)
6061
morph.squeeze = squeeze_coeffs
6162
x_morph_actual, y_morph_actual, x_target_actual, y_target_actual = morph(
6263
x_morph, y_morph, x_target, y_target
6364
)
64-
extrap_index_low = morph.extrap_index_low
65-
extrap_index_high = morph.extrap_index_high
65+
66+
extrap_low = np.where(x_morph < min(x_squeezed))[0]
67+
extrap_high = np.where(x_morph > max(x_squeezed))[0]
68+
extrap_index_low = extrap_low[-1] if extrap_low.size else None
69+
extrap_index_high = extrap_high[0] if extrap_high.size else None
6670
if extrap_index_low is None:
6771
extrap_index_low = 0
6872
elif extrap_index_high is None:
@@ -82,11 +86,9 @@ def test_morphsqueeze(x_morph, x_target, squeeze_coeffs):
8286
y_morph_expected[extrap_index_high:],
8387
atol=1e-3,
8488
)
85-
assert morph.extrap_index_low == extrap_index_low_expected
86-
assert morph.extrap_index_high == extrap_index_high_expected
8789
assert np.allclose(x_morph_actual, x_morph_expected)
88-
assert np.allclose(x_target_actual, x_target)
89-
assert np.allclose(y_target_actual, y_target)
90+
assert np.allclose(x_target_actual, x_target_expected)
91+
assert np.allclose(y_target_actual, y_target_expected)
9092

9193

9294
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)