Skip to content

Commit 9c8fd95

Browse files
committed
Fix another divide by zero warning in transform
1 parent 25cb875 commit 9c8fd95

File tree

2 files changed

+95
-114
lines changed

2 files changed

+95
-114
lines changed

tests/test_diffraction_objects.py

Lines changed: 83 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -403,114 +403,101 @@ def test_dump(tmp_path, mocker):
403403
assert actual == expected
404404

405405

406-
test_init_valid_params = [
407-
( # instantiate just array attributes
408-
{
409-
"xarray": np.array([0.0, 90.0, 180.0]),
410-
"yarray": np.array([1.0, 2.0, 3.0]),
411-
"xtype": "tth",
412-
"wavelength": 4.0 * np.pi,
413-
},
414-
{
415-
"_all_arrays": np.array(
416-
[
417-
[1.0, 0.0, 0.0, np.float64(np.inf)],
418-
[2.0, 1.0 / np.sqrt(2), 90.0, np.sqrt(2) * 2 * np.pi],
419-
[3.0, 1.0, 180.0, 1.0 * 2 * np.pi],
420-
]
421-
),
422-
"metadata": {},
423-
"_input_xtype": "tth",
424-
"name": "",
425-
"scat_quantity": "",
426-
"qmin": np.float64(0.0),
427-
"qmax": np.float64(1.0),
428-
"tthmin": np.float64(0.0),
429-
"tthmax": np.float64(180.0),
430-
"dmin": np.float64(2 * np.pi),
431-
"dmax": np.float64(np.inf),
432-
"wavelength": 4.0 * np.pi,
433-
},
434-
),
435-
( # instantiate just array attributes
436-
{
437-
"xarray": np.array([np.inf, 2 * np.sqrt(2) * np.pi, 2 * np.pi]),
438-
"yarray": np.array([1.0, 2.0, 3.0]),
439-
"xtype": "d",
440-
"wavelength": 4.0 * np.pi,
441-
"scat_quantity": "x-ray",
442-
},
443-
{
444-
"_all_arrays": np.array(
445-
[
446-
[1.0, 0.0, 0.0, np.float64(np.inf)],
447-
[2.0, 1.0 / np.sqrt(2), 90.0, np.sqrt(2) * 2 * np.pi],
448-
[3.0, 1.0, 180.0, 1.0 * 2 * np.pi],
449-
]
450-
),
451-
"metadata": {},
452-
"_input_xtype": "d",
453-
"name": "",
454-
"scat_quantity": "x-ray",
455-
"qmin": np.float64(0.0),
456-
"qmax": np.float64(1.0),
457-
"tthmin": np.float64(0.0),
458-
"tthmax": np.float64(180.0),
459-
"dmin": np.float64(2 * np.pi),
460-
"dmax": np.float64(np.inf),
461-
"wavelength": 4.0 * np.pi,
462-
},
463-
),
464-
]
465-
466-
467406
@pytest.mark.parametrize(
468-
"init_args, expected_do_dict",
469-
test_init_valid_params,
407+
"do_init_args, expected_do_dict",
408+
[
409+
( # instantiate just array attributes
410+
{
411+
"xarray": np.array([0.0, 90.0, 180.0]),
412+
"yarray": np.array([1.0, 2.0, 3.0]),
413+
"xtype": "tth",
414+
"wavelength": 4.0 * np.pi,
415+
},
416+
{
417+
"_all_arrays": np.array(
418+
[
419+
[1.0, 0.0, 0.0, np.float64(np.inf)],
420+
[2.0, 1.0 / np.sqrt(2), 90.0, np.sqrt(2) * 2 * np.pi],
421+
[3.0, 1.0, 180.0, 1.0 * 2 * np.pi],
422+
]
423+
),
424+
"metadata": {},
425+
"_input_xtype": "tth",
426+
"name": "",
427+
"scat_quantity": "",
428+
"qmin": np.float64(0.0),
429+
"qmax": np.float64(1.0),
430+
"tthmin": np.float64(0.0),
431+
"tthmax": np.float64(180.0),
432+
"dmin": np.float64(2 * np.pi),
433+
"dmax": np.float64(np.inf),
434+
"wavelength": 4.0 * np.pi,
435+
},
436+
),
437+
( # instantiate just array attributes
438+
{
439+
"xarray": np.array([np.inf, 2 * np.sqrt(2) * np.pi, 2 * np.pi]),
440+
"yarray": np.array([1.0, 2.0, 3.0]),
441+
"xtype": "d",
442+
"wavelength": 4.0 * np.pi,
443+
"scat_quantity": "x-ray",
444+
},
445+
{
446+
"_all_arrays": np.array(
447+
[
448+
[1.0, 0.0, 0.0, np.float64(np.inf)],
449+
[2.0, 1.0 / np.sqrt(2), 90.0, np.sqrt(2) * 2 * np.pi],
450+
[3.0, 1.0, 180.0, 1.0 * 2 * np.pi],
451+
]
452+
),
453+
"metadata": {},
454+
"_input_xtype": "d",
455+
"name": "",
456+
"scat_quantity": "x-ray",
457+
"qmin": np.float64(0.0),
458+
"qmax": np.float64(1.0),
459+
"tthmin": np.float64(0.0),
460+
"tthmax": np.float64(180.0),
461+
"dmin": np.float64(2 * np.pi),
462+
"dmax": np.float64(np.inf),
463+
"wavelength": 4.0 * np.pi,
464+
},
465+
),
466+
],
470467
)
471-
def test_init_valid(init_args, expected_do_dict):
472-
actual_do_dict = DiffractionObject(**init_args).__dict__
468+
def test_init_valid(do_init_args, expected_do_dict):
469+
actual_do_dict = DiffractionObject(**do_init_args).__dict__
473470
diff = DeepDiff(
474471
actual_do_dict, expected_do_dict, ignore_order=True, significant_digits=13, exclude_paths="root['_id']"
475472
)
476473
assert diff == {}
477474

478475

479-
test_init_invalid_params = [
480-
( # UC1: no arguments provided
481-
{},
482-
"missing 3 required positional arguments: 'xarray', 'yarray', and 'xtype'",
483-
),
484-
( # UC2: only xarray and yarray provided
485-
{"xarray": np.array([0.0, 90.0]), "yarray": np.array([0.0, 90.0])},
486-
"missing 1 required positional argument: 'xtype'",
487-
),
488-
]
489-
490-
491-
@pytest.mark.parametrize("init_args, expected_error_msg", test_init_invalid_params)
476+
@pytest.mark.parametrize(
477+
"do_init_args, expected_error_msg",
478+
[
479+
( # Case 1: no arguments provided
480+
{},
481+
"missing 3 required positional arguments: 'xarray', 'yarray', and 'xtype'",
482+
),
483+
( # Case 2: only xarray and yarray provided
484+
{"xarray": np.array([0.0, 90.0]), "yarray": np.array([0.0, 90.0])},
485+
"missing 1 required positional argument: 'xtype'",
486+
),
487+
],
488+
)
492489
def test_init_invalid_args(
493-
init_args,
490+
do_init_args,
494491
expected_error_msg,
495492
):
496493
with pytest.raises(TypeError, match=expected_error_msg):
497-
DiffractionObject(**init_args)
494+
DiffractionObject(**do_init_args)
498495

499496

500-
def test_all_array_getter():
501-
actual_do = DiffractionObject(
502-
xarray=np.array([0.0, 90.0, 180.0]),
503-
yarray=np.array([1.0, 2.0, 3.0]),
504-
xtype="tth",
505-
wavelength=4.0 * np.pi,
506-
)
507-
expected_all_arrays = np.array(
508-
[
509-
[1.0, 0.0, 0.0, np.float64(np.inf)],
510-
[2.0, 1.0 / np.sqrt(2), 90.0, np.sqrt(2) * 2 * np.pi],
511-
[3.0, 1.0, 180.0, 1.0 * 2 * np.pi],
512-
]
513-
)
497+
def test_all_array_getter(do_minimal_tth):
498+
actual_do = do_minimal_tth
499+
print(actual_do.all_arrays)
500+
expected_all_arrays = [[1, 0.51763809, 30, 12.13818192], [2, 1, 60, 6.28318531]]
514501
assert np.allclose(actual_do.all_arrays, expected_all_arrays)
515502

516503

@@ -575,14 +562,8 @@ def test_input_xtype_setter_error(do_minimal):
575562
do.input_xtype = "q"
576563

577564

578-
def test_copy_object():
579-
do = DiffractionObject(
580-
name="test",
581-
wavelength=4.0 * np.pi,
582-
xarray=np.array([0.0, 90.0, 180.0]),
583-
yarray=np.array([1.0, 2.0, 3.0]),
584-
xtype="tth",
585-
)
565+
def test_copy_object(do_minimal):
566+
do = do_minimal
586567
do_copy = do.copy()
587568
assert do == do_copy
588569
assert id(do) != id(do_copy)

tests/test_transforms.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -146,43 +146,43 @@ def test_q_to_d(q, expected_d, warning_expected):
146146

147147

148148
@pytest.mark.parametrize(
149-
"d, expected_q",
149+
"d, expected_q, zero_divide_error_expected",
150150
[
151151
# UC1: User specified empty d values
152-
(np.array([]), np.array([])),
152+
(np.array([]), np.array([]), False),
153153
# UC2: User specified valid d values
154154
(
155155
np.array([5 * np.pi, 4 * np.pi, 3 * np.pi, 2 * np.pi, np.pi, 0]),
156156
np.array([0.4, 0.5, 0.66667, 1, 2, np.inf]),
157+
True,
157158
),
158159
],
159160
)
160-
def test_d_to_q(d, expected_q):
161-
actual_q = d_to_q(d)
161+
def test_d_to_q(d, expected_q, zero_divide_error_expected):
162+
if zero_divide_error_expected:
163+
with pytest.warns(RuntimeWarning, match="divide by zero encountered in divide"):
164+
actual_q = d_to_q(d)
165+
else:
166+
actual_q = d_to_q(d)
162167
assert np.allclose(actual_q, expected_q)
163168

164169

165170
@pytest.mark.parametrize(
166171
"wavelength, tth, expected_d, divide_by_zero_warning_expected",
167-
[
172+
[
168173
# Test conversion of q to d with valid values
169174
# Case 1: empty tth values, no, expect empty d values
170175
(None, np.array([]), np.array([]), False),
171176
# Case 2: empty tth values, wavelength provided, expect empty d values
172177
(4 * np.pi, np.array([]), np.array([]), False),
173178
# Case 3: User specified valid tth values between 0-180 degrees (without wavelength)
174-
(
175-
None,
176-
np.array([0, 30, 60, 90, 120, 180]),
177-
np.array([0, 1, 2, 3, 4, 5]),
178-
False
179-
),
179+
(None, np.array([0, 30, 60, 90, 120, 180]), np.array([0, 1, 2, 3, 4, 5]), False),
180180
# Case 4: User specified valid tth values between 0-180 degrees (with wavelength)
181181
(
182182
4 * np.pi,
183183
np.array([0, 30.0, 60.0, 90.0, 120.0, 180.0]),
184184
np.array([np.inf, 24.27636, 12.56637, 8.88577, 7.25520, 6.28319]),
185-
True
185+
True,
186186
),
187187
],
188188
)

0 commit comments

Comments
 (0)