Skip to content

Commit 354976a

Browse files
Merge pull request #1361 from datajoint/fix/make-kwargs-tripartite-master
fix: Pass make_kwargs to make_fetch in tripartite pattern
2 parents 512af13 + 43bafd1 commit 354976a

File tree

2 files changed

+87
-3
lines changed

2 files changed

+87
-3
lines changed

src/datajoint/autopopulate.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def _rename_attributes(table, props):
204204
self._key_source *= _rename_attributes(*q)
205205
return self._key_source
206206

207-
def make(self, key: dict[str, Any]) -> None | Generator[Any, Any, None]:
207+
def make(self, key: dict[str, Any], **kwargs) -> None | Generator[Any, Any, None]:
208208
"""
209209
Compute and insert data for one key.
210210
@@ -219,6 +219,9 @@ def make(self, key: dict[str, Any]) -> None | Generator[Any, Any, None]:
219219
----------
220220
key : dict
221221
Primary key value identifying the entity to compute.
222+
**kwargs
223+
Keyword arguments passed from ``populate(make_kwargs=...)``.
224+
These are forwarded to ``make_fetch`` for the tripartite pattern.
222225
223226
Raises
224227
------
@@ -232,7 +235,7 @@ def make(self, key: dict[str, Any]) -> None | Generator[Any, Any, None]:
232235
233236
**Tripartite make**: For long-running computations, implement:
234237
235-
- ``make_fetch(key)``: Fetch data from parent tables
238+
- ``make_fetch(key, **kwargs)``: Fetch data from parent tables
236239
- ``make_compute(key, *fetched_data)``: Compute results
237240
- ``make_insert(key, *computed_result)``: Insert results
238241
@@ -250,7 +253,7 @@ def make(self, key: dict[str, Any]) -> None | Generator[Any, Any, None]:
250253
# User has implemented `_fetch`, `_compute`, and `_insert` methods instead
251254

252255
# Step 1: Fetch data from parent tables
253-
fetched_data = self.make_fetch(key) # fetched_data is a tuple
256+
fetched_data = self.make_fetch(key, **kwargs) # fetched_data is a tuple
254257
computed_result = yield fetched_data # passed as input into make_compute
255258

256259
# Step 2: If computed result is not passed in, compute the result

tests/integration/test_autopopulate.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,84 @@ def make(self, key):
147147
self.insert1(dict(key, crop_image=dict()))
148148

149149
Crop.populate()
150+
151+
152+
def test_make_kwargs_regular(prefix, connection_test):
153+
"""Test that make_kwargs are passed to regular make method."""
154+
schema = dj.Schema(f"{prefix}_make_kwargs_regular", connection=connection_test)
155+
156+
@schema
157+
class Source(dj.Lookup):
158+
definition = """
159+
source_id: int
160+
"""
161+
contents = [(1,), (2,)]
162+
163+
@schema
164+
class Computed(dj.Computed):
165+
definition = """
166+
-> Source
167+
---
168+
multiplier: int
169+
result: int
170+
"""
171+
172+
def make(self, key, multiplier=1):
173+
self.insert1(dict(key, multiplier=multiplier, result=key["source_id"] * multiplier))
174+
175+
# Test without make_kwargs
176+
Computed.populate(Source & "source_id = 1")
177+
assert (Computed & "source_id = 1").fetch1("result") == 1
178+
179+
# Test with make_kwargs
180+
Computed.populate(Source & "source_id = 2", make_kwargs={"multiplier": 10})
181+
assert (Computed & "source_id = 2").fetch1("multiplier") == 10
182+
assert (Computed & "source_id = 2").fetch1("result") == 20
183+
184+
185+
def test_make_kwargs_tripartite(prefix, connection_test):
186+
"""Test that make_kwargs are passed to make_fetch in tripartite pattern (issue #1350)."""
187+
schema = dj.Schema(f"{prefix}_make_kwargs_tripartite", connection=connection_test)
188+
189+
@schema
190+
class Source(dj.Lookup):
191+
definition = """
192+
source_id: int
193+
---
194+
value: int
195+
"""
196+
contents = [(1, 100), (2, 200)]
197+
198+
@schema
199+
class TripartiteComputed(dj.Computed):
200+
definition = """
201+
-> Source
202+
---
203+
scale: int
204+
result: int
205+
"""
206+
207+
def make_fetch(self, key, scale=1):
208+
"""Fetch data with optional scale parameter."""
209+
value = (Source & key).fetch1("value")
210+
return (value, scale)
211+
212+
def make_compute(self, key, value, scale):
213+
"""Compute result using fetched value and scale."""
214+
return (value * scale, scale)
215+
216+
def make_insert(self, key, result, scale):
217+
"""Insert computed result."""
218+
self.insert1(dict(key, scale=scale, result=result))
219+
220+
# Test without make_kwargs (scale defaults to 1)
221+
TripartiteComputed.populate(Source & "source_id = 1")
222+
row = (TripartiteComputed & "source_id = 1").fetch1()
223+
assert row["scale"] == 1
224+
assert row["result"] == 100 # 100 * 1
225+
226+
# Test with make_kwargs (scale = 5)
227+
TripartiteComputed.populate(Source & "source_id = 2", make_kwargs={"scale": 5})
228+
row = (TripartiteComputed & "source_id = 2").fetch1()
229+
assert row["scale"] == 5
230+
assert row["result"] == 1000 # 200 * 5

0 commit comments

Comments
 (0)