Skip to content

Commit 3d4841b

Browse files
committed
Add __eq__ news and test
1 parent 7db3a4f commit 3d4841b

File tree

3 files changed

+102
-3
lines changed

3 files changed

+102
-3
lines changed

news/add-operations-tests.rst

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
**Added:**
2+
3+
* unit tests for __add__ operation for DiffractionObject
4+
5+
**Changed:**
6+
7+
* <news item>
8+
9+
**Deprecated:**
10+
11+
* <news item>
12+
13+
**Removed:**
14+
15+
* <news item>
16+
17+
**Fixed:**
18+
19+
* <news item>
20+
21+
**Security:**
22+
23+
* <news item>

tests/conftest.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,20 @@ def invalid_q_or_d_or_wavelength_error_msg():
6363
"The supplied input array and wavelength will result in an impossible two-theta. "
6464
"Please check these values and re-instantiate the DiffractionObject with correct values."
6565
)
66+
67+
68+
@pytest.fixture
69+
def invalid_add_type_error_msg():
70+
return (
71+
"You may only add a DiffractionObject with another DiffractionObject or a scalar value. "
72+
"Please rerun by adding another DiffractionObject instance or a scalar value. "
73+
"e.g., my_do_1 + my_do_2 or my_do + 10"
74+
)
75+
76+
77+
@pytest.fixture
78+
def x_grid_size_mismatch_error_msg():
79+
return (
80+
"The two objects have different x-array lengths. "
81+
"Please ensure the length of the x-value during initialization is identical."
82+
)

tests/test_diffraction_objects.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def test_diffraction_objects_equality(
155155

156156

157157
@pytest.mark.parametrize(
158-
"xtype, expected_xarray",
158+
"xtype, expected_all_arrays",
159159
[
160160
# Test whether on_xtype returns the correct xarray values.
161161
# C1: tth to tth, expect no change in xarray value
@@ -169,10 +169,10 @@ def test_diffraction_objects_equality(
169169
("d", np.array([12.13818, 6.28319])),
170170
],
171171
)
172-
def test_on_xtype(xtype, expected_xarray, do_minimal_tth):
172+
def test_on_xtype(xtype, expected_all_arrays, do_minimal_tth):
173173
do = do_minimal_tth
174174
actual_xrray, actual_yarray = do.on_xtype(xtype)
175-
assert np.allclose(actual_xrray, expected_xarray)
175+
assert np.allclose(actual_xrray, expected_all_arrays)
176176
assert np.allclose(actual_yarray, np.array([1, 2]))
177177

178178

@@ -702,3 +702,62 @@ def test_copy_object(do_minimal):
702702
do_copy = do.copy()
703703
assert do == do_copy
704704
assert id(do) != id(do_copy)
705+
706+
707+
@pytest.mark.parametrize(
708+
"starting_all_arrays, scalar_value, expected_all_arrays",
709+
[
710+
# Test scalar addition to xarray values (q, tth, d) and expect no change to yarray values
711+
( # C1: Add integer of 5, expect xarray to increase by by 5
712+
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
713+
5,
714+
np.array([[1.0, 5.51763809, 35.0, 17.13818192], [2.0, 6.0, 65.0, 11.28318531]]),
715+
),
716+
( # C2: Add float of 5.1, expect xarray to be added by 5.1
717+
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
718+
5.1,
719+
np.array([[1.0, 5.61763809, 35.1, 17.23818192], [2.0, 6.1, 65.1, 11.38318531]]),
720+
),
721+
],
722+
)
723+
def test_addition_operator_by_scalar(starting_all_arrays, scalar_value, expected_all_arrays, do_minimal_tth):
724+
do = do_minimal_tth
725+
assert np.allclose(do.all_arrays, starting_all_arrays)
726+
do_sum = do + scalar_value
727+
assert np.allclose(do_sum.all_arrays, expected_all_arrays)
728+
729+
730+
@pytest.mark.parametrize(
731+
"LHS_all_arrays, RHS_all_arrays, expected_all_arrays_sum",
732+
[
733+
# Test addition of two DO objects, expect combined xarray values (q, tth, d) and no change to yarray
734+
( # C1: Add two DO objects with identical xarray values, expect sum of xarray values
735+
(np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),),
736+
(np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),),
737+
np.array([[1.0, 1.03527618, 60.0, 24.27636384], [2.0, 2.0, 120.0, 12.56637061]]),
738+
),
739+
],
740+
)
741+
def test_addition_operator_by_another_do(LHS_all_arrays, RHS_all_arrays, expected_all_arrays_sum, do_minimal_tth):
742+
assert np.allclose(do_minimal_tth.all_arrays, LHS_all_arrays)
743+
do_LHS = do_minimal_tth
744+
do_RHS = do_minimal_tth
745+
do_sum = do_LHS + do_RHS
746+
assert np.allclose(do_LHS.all_arrays, LHS_all_arrays)
747+
assert np.allclose(do_RHS.all_arrays, RHS_all_arrays)
748+
assert np.allclose(do_sum.all_arrays, expected_all_arrays_sum)
749+
750+
751+
def test_addition_operator_invalid_type(do_minimal_tth, invalid_add_type_error_msg):
752+
# Add a string to a DO object, expect TypeError, only scalar (int, float) allowed for addition
753+
do_LHS = do_minimal_tth
754+
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
755+
do_LHS + "string_value"
756+
757+
758+
def test_addition_operator_invalid_xarray_length(do_minimal, do_minimal_tth, x_grid_size_mismatch_error_msg):
759+
# Combine two DO objects, one with empty xarrays (do_minimal) and the other with non-empty xarrays
760+
do_LHS = do_minimal
761+
do_RHS = do_minimal_tth
762+
with pytest.raises(ValueError, match=re.escape(x_grid_size_mismatch_error_msg)):
763+
do_LHS + do_RHS

0 commit comments

Comments
 (0)