Skip to content

Commit 7db3a4f

Browse files
committed
Refactor __eq__ function for DO
1 parent 6a9419a commit 7db3a4f

File tree

1 file changed

+50
-44
lines changed

1 file changed

+50
-44
lines changed

src/diffpy/utils/diffraction_objects.py

Lines changed: 50 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,15 @@
1414
XQUANTITIES = ANGLEQUANTITIES + DQUANTITIES + QQUANTITIES
1515
XUNITS = ["degrees", "radians", "rad", "deg", "inv_angs", "inv_nm", "nm-1", "A-1"]
1616

17-
x_grid_emsg = (
18-
"objects are not on the same x-grid. You may add them using the self.add method "
19-
"and specifying how to handle the mismatch."
17+
x_grid_length_mismatch_emsg = (
18+
"The two objects have different x-array lengths. "
19+
"Please ensure the length of the x-value during initialization is identical."
20+
)
21+
22+
invalid_add_type_emsg = (
23+
"You may only add a DiffractionObject with another DiffractionObject or a scalar value. "
24+
"Please rerun by adding another DiffractionObject instance or a scalar value. "
25+
"e.g., my_do_1 + my_do_2 or my_do + 10"
2026
)
2127

2228

@@ -169,32 +175,44 @@ def __eq__(self, other):
169175
return True
170176

171177
def __add__(self, other):
172-
summed = deepcopy(self)
173-
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
174-
summed.on_tth[1] = self.on_tth[1] + other
175-
summed.on_q[1] = self.on_q[1] + other
176-
elif not isinstance(other, DiffractionObject):
177-
raise TypeError("I only know how to sum two DiffractionObject objects")
178-
elif self.on_tth[0].all() != other.on_tth[0].all():
179-
raise RuntimeError(x_grid_emsg)
180-
else:
181-
summed.on_tth[1] = self.on_tth[1] + other.on_tth[1]
182-
summed.on_q[1] = self.on_q[1] + other.on_q[1]
183-
return summed
178+
"""Add a scalar value or another DiffractionObject to the xarrays of
179+
the DiffractionObject.
184180
185-
def __radd__(self, other):
186-
summed = deepcopy(self)
187-
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
188-
summed.on_tth[1] = self.on_tth[1] + other
189-
summed.on_q[1] = self.on_q[1] + other
190-
elif not isinstance(other, DiffractionObject):
191-
raise TypeError("I only know how to sum two Scattering_object objects")
192-
elif self.on_tth[0].all() != other.on_tth[0].all():
193-
raise RuntimeError(x_grid_emsg)
181+
Parameters
182+
----------
183+
other : DiffractionObject or int or float
184+
The object to add to the current DiffractionObject. If `other` is a scalar value,
185+
it will be added to all xarrays. The length of the xarrays must match if `other` is
186+
an instance of DiffractionObject.
187+
188+
Returns
189+
-------
190+
DiffractionObject
191+
The new and deep-copied DiffractionObject instance after adding values to the xarrays.
192+
193+
Raises
194+
------
195+
ValueError
196+
Raised when the length of the xarrays of the two DiffractionObject instances do not match.
197+
TypeError
198+
Raised when the type of `other` is not an instance of DiffractionObject, int, or float.
199+
"""
200+
summed_do = deepcopy(self)
201+
# Add scalar value to all xarrays by broadcasting
202+
if isinstance(other, (int, float)):
203+
summed_do._all_arrays[:, 1] += other
204+
summed_do._all_arrays[:, 2] += other
205+
summed_do._all_arrays[:, 3] += other
206+
# Add xarrays of two DiffractionObject instances
207+
elif isinstance(other, DiffractionObject):
208+
if len(self.on_tth()[0]) != len(other.on_tth()[0]):
209+
raise ValueError(x_grid_length_mismatch_emsg)
210+
summed_do._all_arrays[:, 1] += other.on_q()[0]
211+
summed_do._all_arrays[:, 2] += other.on_tth()[0]
212+
summed_do._all_arrays[:, 3] += other.on_d()[0]
194213
else:
195-
summed.on_tth[1] = self.on_tth[1] + other.on_tth[1]
196-
summed.on_q[1] = self.on_q[1] + other.on_q[1]
197-
return summed
214+
raise TypeError(invalid_add_type_emsg)
215+
return summed_do
198216

199217
def __sub__(self, other):
200218
subtracted = deepcopy(self)
@@ -204,7 +222,7 @@ def __sub__(self, other):
204222
elif not isinstance(other, DiffractionObject):
205223
raise TypeError("I only know how to subtract two Scattering_object objects")
206224
elif self.on_tth[0].all() != other.on_tth[0].all():
207-
raise RuntimeError(x_grid_emsg)
225+
raise RuntimeError(x_grid_length_mismatch_emsg)
208226
else:
209227
subtracted.on_tth[1] = self.on_tth[1] - other.on_tth[1]
210228
subtracted.on_q[1] = self.on_q[1] - other.on_q[1]
@@ -218,7 +236,7 @@ def __rsub__(self, other):
218236
elif not isinstance(other, DiffractionObject):
219237
raise TypeError("I only know how to subtract two Scattering_object objects")
220238
elif self.on_tth[0].all() != other.on_tth[0].all():
221-
raise RuntimeError(x_grid_emsg)
239+
raise RuntimeError(x_grid_length_mismatch_emsg)
222240
else:
223241
subtracted.on_tth[1] = other.on_tth[1] - self.on_tth[1]
224242
subtracted.on_q[1] = other.on_q[1] - self.on_q[1]
@@ -232,19 +250,7 @@ def __mul__(self, other):
232250
elif not isinstance(other, DiffractionObject):
233251
raise TypeError("I only know how to multiply two Scattering_object objects")
234252
elif self.on_tth[0].all() != other.on_tth[0].all():
235-
raise RuntimeError(x_grid_emsg)
236-
else:
237-
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1]
238-
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
239-
return multiplied
240-
241-
def __rmul__(self, other):
242-
multiplied = deepcopy(self)
243-
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
244-
multiplied.on_tth[1] = other * self.on_tth[1]
245-
multiplied.on_q[1] = other * self.on_q[1]
246-
elif self.on_tth[0].all() != other.on_tth[0].all():
247-
raise RuntimeError(x_grid_emsg)
253+
raise RuntimeError(x_grid_length_mismatch_emsg)
248254
else:
249255
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1]
250256
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
@@ -258,7 +264,7 @@ def __truediv__(self, other):
258264
elif not isinstance(other, DiffractionObject):
259265
raise TypeError("I only know how to multiply two Scattering_object objects")
260266
elif self.on_tth[0].all() != other.on_tth[0].all():
261-
raise RuntimeError(x_grid_emsg)
267+
raise RuntimeError(x_grid_length_mismatch_emsg)
262268
else:
263269
divided.on_tth[1] = self.on_tth[1] / other.on_tth[1]
264270
divided.on_q[1] = self.on_q[1] / other.on_q[1]
@@ -270,7 +276,7 @@ def __rtruediv__(self, other):
270276
divided.on_tth[1] = other / self.on_tth[1]
271277
divided.on_q[1] = other / self.on_q[1]
272278
elif self.on_tth[0].all() != other.on_tth[0].all():
273-
raise RuntimeError(x_grid_emsg)
279+
raise RuntimeError(x_grid_length_mismatch_emsg)
274280
else:
275281
divided.on_tth[1] = other.on_tth[1] / self.on_tth[1]
276282
divided.on_q[1] = other.on_q[1] / self.on_q[1]

0 commit comments

Comments
 (0)