diff --git a/src/datajoint/autopopulate.py b/src/datajoint/autopopulate.py index 24d6b17aa..8f0946a06 100644 --- a/src/datajoint/autopopulate.py +++ b/src/datajoint/autopopulate.py @@ -89,6 +89,46 @@ class AutoPopulate: _key_source = None _allow_insert = False _jobs = None + _upstream = None # set per-make() by _populate1; see `upstream` property below + + @property + def upstream(self): + """ + Pre-restricted ancestor view for the current ``make(self, key)`` call. + + Inside ``make()``, ``self.upstream`` is a ``Diagram`` constructed via + :meth:`Diagram.trace(self & key) <datajoint.diagram.Diagram.trace>`. Use + ``self.upstream[T]`` to obtain a pre-restricted ``QueryExpression`` + (or ``FreeTable``, when indexed by a string) for any ancestor of + ``self``. + + Reading via ``self.upstream`` is the provenance-safe pattern: the + framework guarantees the restriction matches the current ``key``, + and indexing a non-ancestor table raises ``DataJointError``. See + :doc:`reference/specs/provenance` for the contract. + + Raises + ------ + DataJointError + If accessed outside ``make()`` execution. To construct a trace + explicitly, use ``dj.Diagram.trace(self & key)``. + + Examples + -------- + :: + + def make(self, key): + date = self.upstream[Session].fetch1("session_date") + traces = self.upstream[ExtractTraces].to_arrays("trace") + self.insert1({**key, "summary": compute(traces, date)}) + """ + if self._upstream is None: + raise DataJointError( + "self.upstream is only available inside make(). " + "Outside make(), construct a trace explicitly: " + "dj.Diagram.trace(self & key)." + ) + return self._upstream class _JobsDescriptor: """Descriptor allowing jobs access on both class and instance.""" @@ -611,6 +651,13 @@ def _populate1( logger.jobs(f"Making {key} -> {self.full_table_name}") self.__class__._allow_insert = True + # Pre-construct the upstream view for this make() call. 
Lazy — only + # `dj.Diagram.trace(self & key)` runs here (graph copy); SQL fetches wait + # for self.upstream[T]. Built inside `try` so a failure reaches `finally`. try: + from .diagram import Diagram + + self._upstream = Diagram.trace(self & dict(key)) + if not is_generator: make(dict(key), **(make_kwargs or {})) @@ -668,6 +715,10 @@ def _populate1( return True finally: self.__class__._allow_insert = False + # Clear the per-make() upstream view so subsequent attribute + # access raises a clear error rather than silently using a + # stale trace from the previous make() call. + self._upstream = None def progress(self, *restrictions: Any, display: bool = False) -> tuple[int, int]: """ diff --git a/tests/integration/test_autopopulate.py b/tests/integration/test_autopopulate.py index 02ba69d6b..0f7c60b5c 100644 --- a/tests/integration/test_autopopulate.py +++ b/tests/integration/test_autopopulate.py @@ -354,6 +354,245 @@ def make_insert(self, key, result, scale): assert row["result"] == 1000 # 200 * 5 +# ========================================================================= +# #1424: self.upstream pre-restricted ancestor access in make() +# ========================================================================= + + +def test_upstream_provides_pre_restricted_ancestor(prefix, connection_test): + """make() can read self.upstream[Ancestor] and get pre-restricted data.""" + schema = dj.Schema(f"{prefix}_upstream_basic", connection=connection_test) + + @schema + class Subject(dj.Lookup): + definition = """ + subject_id : int32 + --- + name : varchar(64) + """ + contents = [(1, "alice"), (2, "bob")] + + @schema + class Greeting(dj.Computed): + definition = """ + -> Subject + --- + greeting : varchar(128) + """ + + def make(self, key): + # Provenance-safe read: self.upstream pre-restricted to current key + name = self.upstream[Subject].fetch1("name") + self.insert1({**key, "greeting": f"Hello, {name}!"}) + + Greeting.populate() + assert (Greeting & {"subject_id": 1}).fetch1("greeting") == 
"Hello, alice!" + assert (Greeting & {"subject_id": 2}).fetch1("greeting") == "Hello, bob!" + + +def test_upstream_rejects_non_ancestor(prefix, connection_test): + """self.upstream[T] for a non-ancestor table raises inside make().""" + schema = dj.Schema(f"{prefix}_upstream_non_ancestor", connection=connection_test) + + @schema + class Subject(dj.Lookup): + definition = """ + subject_id : int32 + """ + contents = [(1,)] + + @schema + class Unrelated(dj.Lookup): + definition = """ + u_id : int32 + """ + contents = [(99,)] + + captured_errors: list[tuple[str, Exception]] = [] + + @schema + class Bad(dj.Computed): + definition = """ + -> Subject + --- + ok : tinyint + """ + + def make(self, key): + # class-form lookup (Diagram.__getitem__ class branch) + try: + self.upstream[Unrelated] + except DataJointError as exc: + captured_errors.append(("class", exc)) + # string-form lookup (the separate FreeTable/string branch) + try: + self.upstream[Unrelated.full_table_name] + except DataJointError as exc: + captured_errors.append(("string", exc)) + # Insert anyway so populate doesn't fail + self.insert1({**key, "ok": 1}) + + Bad.populate() + # Both the class-form and string-form lookups must reject the non-ancestor. 
+ forms = {form for form, _ in captured_errors} + assert forms == {"class", "string"}, f"expected both branches to raise, got {forms}" + class_err = next(exc for form, exc in captured_errors if form == "class") + assert "not in this trace" in str(class_err).lower() + + +def test_upstream_unset_outside_make(prefix, connection_test): + """Accessing self.upstream outside of make() raises a clear error.""" + schema = dj.Schema(f"{prefix}_upstream_outside_make", connection=connection_test) + + @schema + class Source(dj.Lookup): + definition = """ + source_id : int32 + """ + contents = [(1,)] + + @schema + class Derived(dj.Computed): + definition = """ + -> Source + --- + val : int32 + """ + + def make(self, key): + self.insert1({**key, "val": 0}) + + with pytest.raises(DataJointError, match="only available inside make"): + Derived().upstream + + +def test_upstream_cleared_after_make(prefix, connection_test): + """After make() completes, the SAME instance that ran make() has its + self.upstream cleared. Capturing the populate instance is what gives this + test teeth: it would FAIL if the `finally: self._upstream = None` line were + removed (a fresh-instance probe would pass regardless).""" + schema = dj.Schema(f"{prefix}_upstream_cleared", connection=connection_test) + + @schema + class Source(dj.Lookup): + definition = """ + source_id : int32 + """ + contents = [(1,)] + + captured = [] + + @schema + class Derived(dj.Computed): + definition = """ + -> Source + --- + val : int32 + """ + + def make(self, key): + captured.append(self) # the actual populate instance + assert self._upstream is not None # set for the duration of make() + self.insert1({**key, "val": 0}) + + Derived.populate() + assert captured, "make() did not run" + inst = captured[0] + # The finally block must have cleared _upstream on this very instance. 
+ assert inst._upstream is None + with pytest.raises(DataJointError, match="only available inside make"): + inst.upstream + + +def test_upstream_cleared_after_make_raises(prefix, connection_test): + """The reset lives in `finally` specifically so it survives an exception in + make(). Force make() to raise and assert the populate instance's + self.upstream is still cleared.""" + schema = dj.Schema(f"{prefix}_upstream_exc", connection=connection_test) + + @schema + class Source(dj.Lookup): + definition = """ + source_id : int32 + """ + contents = [(1,)] + + captured = [] + + @schema + class Boom(dj.Computed): + definition = """ + -> Source + --- + val : int32 + """ + + def make(self, key): + captured.append(self) + assert self._upstream is not None + raise RuntimeError("make failed on purpose") + + with pytest.raises(RuntimeError, match="make failed on purpose"): + Boom.populate(suppress_errors=False) + assert captured, "make() did not run" + inst = captured[0] + # Cleared by the finally block even though make() raised. + assert inst._upstream is None + with pytest.raises(DataJointError, match="only available inside make"): + inst.upstream + + +def test_upstream_seen_across_tripartite_make(prefix, connection_test): + """The tripartite make() sees the SAME self.upstream object across all three + phases (fetch / compute / insert) for a given key — constructed once, + shared. 
Asserted via object identity, not just a correct result.""" + schema = dj.Schema(f"{prefix}_upstream_tripartite", connection=connection_test) + + @schema + class Source(dj.Lookup): + definition = """ + source_id : int32 + --- + value : int32 + """ + contents = [(1, 100), (2, 200)] + + seen = [] # (source_id, phase, id(self._upstream)) + + @schema + class TriComputed(dj.Computed): + definition = """ + -> Source + --- + result : int32 + """ + + def make_fetch(self, key): + seen.append((key["source_id"], "fetch", id(self._upstream))) + return (self.upstream[Source].fetch1("value"),) + + def make_compute(self, key, value): + seen.append((key["source_id"], "compute", id(self._upstream))) + return (value * 2,) + + def make_insert(self, key, doubled): + seen.append((key["source_id"], "insert", id(self._upstream))) + self.insert1({**key, "result": doubled}) + + TriComputed.populate() + assert (TriComputed & {"source_id": 1}).fetch1("result") == 200 + assert (TriComputed & {"source_id": 2}).fetch1("result") == 400 + + # Every phase that ran for a given key must have observed one and the same + # self.upstream object (not None, not rebuilt per phase). + ids_by_key = {} + for sid, _phase, uid in seen: + ids_by_key.setdefault(sid, set()).add(uid) + assert ids_by_key, "tripartite make did not run" + for sid, ids in ids_by_key.items(): + assert len(ids) == 1, f"source_id={sid}: self.upstream differed across phases: {ids}" + + def test_populate_reserve_jobs_respects_restrictions(clean_autopopulate, subject, experiment): """Regression test for #1413: populate() with reserve_jobs=True must honour restrictions.