Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions src/diffpy/utils/diffraction_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,25 +248,27 @@ def _set_array_from_range(self, begin, end, step_size=None, n_steps=None):
array = np.linspace(begin, end, n_steps)
return array

def get_angle_index(self, angle):
def get_array_index(self, xtype, value):
"""
returns the index of a given angle in the angles list
returns the index of a given value in the array associated with the specified xtype

Parameters
----------
angle float
the angle to search for
xtype str
the xtype used to access the array
value float
the target value to search for

Returns
-------
the index of the angle in the angles list
the index of the value in the array
"""
if not hasattr(self, "angles"):
self.angles = np.array([])
for i, target in enumerate(self.angles):
if angle == target:
if self.on_xtype(xtype) is None:
raise ValueError(_xtype_wmsg(xtype))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I can remove this if I let on_xtype() to raise an error for invalid xtypes. Will do that in the other PR.

for i, target in enumerate(self.on_xtype(xtype)[0]):
if value == target:
return i
raise IndexError(f"WARNING: no angle {angle} found in angles list.")
raise IndexError(f"WARNING: no matching value {value} found in the {xtype} array.")

def _set_xarrays(self, xarray, xtype):
self.all_arrays = np.empty(shape=(len(xarray), 4))
Expand Down
49 changes: 35 additions & 14 deletions tests/test_diffraction_objects.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import re
from pathlib import Path

import numpy as np
import pytest
from freezegun import freeze_time

from diffpy.utils.diffraction_objects import DiffractionObject
from diffpy.utils.diffraction_objects import XQUANTITIES, DiffractionObject
from diffpy.utils.transforms import wavelength_warning_emsg


Expand Down Expand Up @@ -212,21 +213,41 @@ def _test_valid_diffraction_objects(actual_diffraction_object, function, expecte


def test_get_angle_index():
test = DiffractionObject()
test.angles = np.array([10, 20, 30, 40, 50, 60])
actual_angle_index = test.get_angle_index(angle=10)
assert actual_angle_index == 0
test = DiffractionObject(
wavelength=0.71, xarray=np.array([30, 60, 90]), yarray=np.array([1, 2, 3]), xtype="tth"
)
actual_index = test.get_array_index(xtype="tth", value=30)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need more cases. Here everything is integers and there is a match. What do we want to happen if the value lies between two other values? Return nearest and a warning? What if it lies outside the range of values? Return nearest and a warning? What if it is really far away? Have a threshold after which we raise and error?

Copy link
Contributor Author

@yucongalicechen yucongalicechen Dec 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sbillinge I addressed this in the new commit. please review.

assert actual_index == 0


params_index_bad = [
# UC1: empty array
(
[0.71, np.array([]), np.array([]), "tth", "tth", 10],
[IndexError, "WARNING: no matching value 10 found in the tth array."],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this an error or a warning?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be an error. I've edited the error message.

),
# UC2: invalid xtype
(
[None, np.array([]), np.array([]), "tth", "invalid", 10],
[
ValueError,
f"WARNING: I don't know how to handle the xtype, 'invalid'. "
f"Please rerun specifying an xtype from {*XQUANTITIES, }",
],
),
# UC3: pre-defined array with non-matching value
(
[0.71, np.array([30, 60, 90]), np.array([1, 2, 3]), "tth", "q", 30],
[IndexError, "WARNING: no matching value 30 found in the q array."],
),
]


def test_get_angle_index_bad():
test = DiffractionObject()
# empty angles list
with pytest.raises(IndexError, match="WARNING: no angle 11 found in angles list."):
test.get_angle_index(angle=11)
# pre-defined angles list
test.angles = np.array([10, 20, 30, 40, 50, 60])
with pytest.raises(IndexError, match="WARNING: no angle 11 found in angles list."):
test.get_angle_index(angle=11)
@pytest.mark.parametrize("inputs, expected", params_index_bad)
def test_get_angle_index_bad(inputs, expected):
test = DiffractionObject(wavelength=inputs[0], xarray=inputs[1], yarray=inputs[2], xtype=inputs[3])
with pytest.raises(expected[0], match=re.escape(expected[1])):
test.get_array_index(xtype=inputs[4], value=inputs[5])


def test_dump(tmp_path, mocker):
Expand Down
Loading