11"""Class MorphSqueeze -- Apply a polynomial to squeeze the morph
22function."""
33
4+ import numpy
45from numpy .polynomial import Polynomial
56from scipy .interpolate import CubicSpline
67
@@ -68,8 +69,68 @@ class MorphSqueeze(Morph):
6869 squeeze_cutoff_low = None
6970 squeeze_cutoff_high = None
7071
71- def __init__ (self , config = None ):
72+ def __init__ (self , config = None , check_increase = False ):
7273 super ().__init__ (config )
74+ self .check_increase = check_increase
75+
76+ def _set_squeeze_info (self , x , x_sorted ):
77+ self .squeeze_info = {"monotonic" : True , "overlapping_regions" : None }
78+ if list (x ) != list (x_sorted ):
79+ if self .check_increase :
80+ raise ValueError (
81+ "Squeezed grid is not strictly increasing."
82+ "Please (1) decrease the order of your polynomial and "
83+ "(2) ensure that the initial polynomial morph result in "
84+ "good agreement between your reference and "
85+ "objective functions."
86+ )
87+ else :
88+ overlapping_regions = self ._get_overlapping_regions (x )
89+ self .squeeze_info ["monotonic" ] = False
90+ self .squeeze_info ["overlapping_regions" ] = overlapping_regions
91+
92+ def _sort_squeeze (self , x , y ):
93+ """Sort x,y according to the value of x."""
94+ xy = list (zip (x , y ))
95+ xy_sorted = sorted (xy , key = lambda pair : pair [0 ])
96+ x_sorted , y_sorted = list (zip (* xy_sorted ))
97+ return x_sorted , y_sorted
98+
99+ def _get_overlapping_regions (self , x ):
100+ diffx = numpy .diff (x )
101+ monotomic_regions = []
102+ monotomic_signs = [numpy .sign (diffx [0 ])]
103+ current_region = [x [0 ], x [1 ]]
104+ for i in range (1 , len (diffx )):
105+ if numpy .sign (diffx [i ]) == monotomic_signs [- 1 ]:
106+ current_region .append (x [i + 1 ])
107+ else :
108+ monotomic_regions .append (current_region )
109+ monotomic_signs .append (diffx [i ])
110+ current_region = [x [i + 1 ]]
111+ monotomic_regions .append (current_region )
112+ overlapping_regions_sign = - 1 if x [0 ] < x [- 1 ] else 1
113+ overlapping_regions_x = [
114+ monotomic_regions [i ]
115+ for i in range (len (monotomic_regions ))
116+ if monotomic_signs [i ] == overlapping_regions_sign
117+ ]
118+ overlapping_regions = [
119+ (min (region ), max (region )) for region in overlapping_regions_x
120+ ]
121+ return overlapping_regions
122+
123+ def _handle_duplicates (self , x , y ):
124+ """Remove duplicated x and use the mean value of y corresponded
125+ to the duplicated x."""
126+ unq_x , unq_inv = numpy .unique (x , return_inverse = True )
127+ if len (unq_x ) == len (x ):
128+ return x , y
129+ else :
130+ y_avg = numpy .zeros_like (unq_x )
131+ for i in range (len (unq_x )):
132+ y_avg [i ] = numpy .array (y )[unq_inv == i ].mean ()
133+ return unq_x , y_avg
73134
74135 def morph (self , x_morph , y_morph , x_target , y_target ):
75136 """Apply a polynomial to squeeze the morph function.
@@ -82,9 +143,16 @@ def morph(self, x_morph, y_morph, x_target, y_target):
82143 coeffs = [self .squeeze [f"a{ i } " ] for i in range (len (self .squeeze ))]
83144 squeeze_polynomial = Polynomial (coeffs )
84145 x_squeezed = self .x_morph_in + squeeze_polynomial (self .x_morph_in )
85- self .y_morph_out = CubicSpline (x_squeezed , self .y_morph_in )(
146+ x_squeezed_sorted , y_morph_sorted = self ._sort_squeeze (
147+ x_squeezed , self .y_morph_in
148+ )
149+ self ._set_squeeze_info (x_squeezed_sorted , x_squeezed )
150+ x_squeezed_sorted , y_morph_sorted = self ._handle_duplicates (
151+ x_squeezed_sorted , y_morph_sorted
152+ )
153+ self .y_morph_out = CubicSpline (x_squeezed_sorted , y_morph_sorted )(
86154 self .x_morph_in
87155 )
88- self .set_extrapolation_info (x_squeezed , self .x_morph_in )
156+ self .set_extrapolation_info (x_squeezed_sorted , self .x_morph_in )
89157
90158 return self .xyallout
0 commit comments