Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions news/array_index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
**Added:**

* function to return the index of the closest value to the specified value in an array.

**Changed:**

* <news item>

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* <news item>

**Security:**

* <news item>
32 changes: 23 additions & 9 deletions src/diffpy/utils/diffraction_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,15 +248,29 @@ def _set_array_from_range(self, begin, end, step_size=None, n_steps=None):
array = np.linspace(begin, end, n_steps)
return array

def get_angle_index(self, angle):
count = 0
for i, target in enumerate(self.angles):
if angle == target:
return i
else:
count += 1
if count >= len(self.angles):
raise IndexError(f"WARNING: no angle {angle} found in angles list")
def get_array_index(self, value, xtype=None):
"""
returns the index of the closest value in the array associated with the specified xtype

Parameters
----------
xtype str
the xtype used to access the array
value float
the target value to search for

Returns
-------
the index of the value in the array
"""

if xtype is None:
xtype = self.input_xtype
array = self.on_xtype(xtype)[0]
if len(array) == 0:
raise ValueError(f"The '{xtype}' array is empty. Please ensure it is initialized.")
i = (np.abs(array - value)).argmin()
return i

def _set_xarrays(self, xarray, xtype):
self.all_arrays = np.empty(shape=(len(xarray), 4))
Expand Down
26 changes: 26 additions & 0 deletions tests/test_diffraction_objects.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -211,6 +212,31 @@ def _test_valid_diffraction_objects(actual_diffraction_object, function, expecte
return np.allclose(actual_array, expected_array)


params_index = [
# UC1: exact match
([4 * np.pi, np.array([30.005, 60]), np.array([1, 2]), "tth", "tth", 30.005], [0]),
# UC2: target value lies in the array, returns the (first) closest index
([4 * np.pi, np.array([30, 60]), np.array([1, 2]), "tth", "tth", 45], [0]),
([4 * np.pi, np.array([30, 60]), np.array([1, 2]), "tth", "q", 0.25], [0]),
# UC3: target value out of the range, returns the closest index
([4 * np.pi, np.array([0.25, 0.5, 0.71]), np.array([1, 2, 3]), "q", "q", 0.1], [0]),
([4 * np.pi, np.array([30, 60]), np.array([1, 2]), "tth", "tth", 63], [1]),
]


@pytest.mark.parametrize("inputs, expected", params_index)
def test_get_array_index(inputs, expected):
test = DiffractionObject(wavelength=inputs[0], xarray=inputs[1], yarray=inputs[2], xtype=inputs[3])
actual = test.get_array_index(value=inputs[5], xtype=inputs[4])
assert actual == expected[0]


def test_get_array_index_bad():
test = DiffractionObject(wavelength=2 * np.pi, xarray=np.array([]), yarray=np.array([]), xtype="tth")
with pytest.raises(ValueError, match=re.escape("The 'tth' array is empty. Please ensure it is initialized.")):
test.get_array_index(value=30)


def test_dump(tmp_path, mocker):
x, y = np.linspace(0, 5, 6), np.linspace(0, 5, 6)
directory = Path(tmp_path)
Expand Down
Loading