Skip to content

Commit 8d28b75

Browse files
Merge remote-tracking branch 'origin/maint/2.0'
2 parents 9720e26 + da0cc0c commit 8d28b75

File tree

2 files changed

+84
-1
lines changed

2 files changed

+84
-1
lines changed

src/datajoint/expression.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -729,12 +729,46 @@ def fetch(
729729

730730
# Handle specific attributes requested
731731
if attrs:
732+
# Check for special 'KEY' attribute
733+
def is_key(attr):
734+
return attr == "KEY"
735+
736+
has_key = any(is_key(a) for a in attrs)
737+
738+
# Handle fetch('KEY') alone - return list of primary key dicts
739+
if has_key and len(attrs) == 1:
740+
return list(self.keys(order_by=order_by, limit=limit, offset=offset))
741+
732742
if as_dict is True:
733743
# fetch('col1', 'col2', as_dict=True) -> list of dicts
734-
return self.proj(*attrs).to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze)
744+
# Replace KEY with primary key columns
745+
proj_attrs = []
746+
for attr in attrs:
747+
if is_key(attr):
748+
proj_attrs.extend(self.primary_key)
749+
else:
750+
proj_attrs.append(attr)
751+
return self.proj(*proj_attrs).to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze)
735752
else:
736753
# fetch('col1', 'col2') or fetch('col1', 'col2', as_dict=False) -> tuple of arrays
737754
# This matches DJ 1.x behavior where fetch('col') returns array(['alpha', 'beta'])
755+
if has_key:
756+
# Need to handle KEY specially - it returns list of dicts, not array
757+
proj_attrs = []
758+
for attr in attrs:
759+
if is_key(attr):
760+
proj_attrs.extend(self.primary_key)
761+
else:
762+
proj_attrs.append(attr)
763+
dicts = self.proj(*proj_attrs).to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze)
764+
# Build result, with KEY returning list of dicts
765+
results = []
766+
for attr in attrs:
767+
if is_key(attr):
768+
results.append([{k: d[k] for k in self.primary_key} for d in dicts])
769+
else:
770+
results.append(np.array([d[attr] for d in dicts]))
771+
return results[0] if len(attrs) == 1 else tuple(results)
738772
return self.to_arrays(*attrs, order_by=order_by, limit=limit, offset=offset, squeeze=squeeze)
739773

740774
# Handle as_dict=True -> to_dicts()

tests/integration/test_fetch.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,3 +457,52 @@ def test_to_arrays_inhomogeneous_shapes_second_axis(schema_any):
457457
assert data[0].shape == (100,)
458458
assert data[1].shape == (1, 100)
459459
assert data[2].shape == (2, 100)
460+
461+
462+
def test_fetch_KEY(lang, languages):
463+
"""Test fetch('KEY') returns list of primary key dicts.
464+
465+
Regression test for https://github.com/datajoint/datajoint-python/issues/1381
466+
"""
467+
import warnings
468+
469+
# Suppress deprecation warning for fetch
470+
with warnings.catch_warnings():
471+
warnings.simplefilter("ignore", DeprecationWarning)
472+
473+
# fetch('KEY') should return list of primary key dicts
474+
keys = lang.fetch("KEY")
475+
assert isinstance(keys, list)
476+
assert len(keys) == len(languages)
477+
assert all(isinstance(k, dict) for k in keys)
478+
# Primary key is (name, language)
479+
assert all(set(k.keys()) == {"name", "language"} for k in keys)
480+
481+
482+
def test_fetch1_KEY(lang):
483+
"""Test fetch1('KEY') returns primary key dict.
484+
485+
Regression test for https://github.com/datajoint/datajoint-python/issues/1381
486+
"""
487+
key = {"name": "Edgar", "language": "Japanese"}
488+
result = (lang & key).fetch1("KEY")
489+
assert isinstance(result, dict)
490+
assert result == key
491+
492+
493+
def test_fetch_KEY_with_other_attrs(lang):
494+
"""Test fetch('KEY', 'name') returns (keys_list, name_array).
495+
496+
Regression test for https://github.com/datajoint/datajoint-python/issues/1381
497+
"""
498+
import warnings
499+
500+
with warnings.catch_warnings():
501+
warnings.simplefilter("ignore", DeprecationWarning)
502+
503+
# fetch('KEY', 'name') should return tuple of (list of dicts, array)
504+
keys, names = lang.fetch("KEY", "name")
505+
assert isinstance(keys, list)
506+
assert all(isinstance(k, dict) for k in keys)
507+
assert isinstance(names, np.ndarray)
508+
assert len(keys) == len(names)

0 commit comments

Comments
 (0)