Lokasi ngalangkungan proxy:   [ UP ]  
[Ngawartoskeun bug]   [Panyetelan cookie]                
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions src/datajoint/autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,46 @@ class AutoPopulate:
_key_source = None
_allow_insert = False
_jobs = None
_upstream = None # set per-make() by _populate_one; see `upstream` property below

@property
def upstream(self):
"""
Pre-restricted ancestor view for the current ``make(self, key)`` call.

Inside ``make()``, ``self.upstream`` is a ``Diagram`` constructed via
:meth:`Diagram.trace(self & key) <datajoint.Diagram.trace>`. Use
``self.upstream[T]`` to obtain a pre-restricted ``QueryExpression``
(or ``FreeTable``, when indexed by a string) for any ancestor of
``self``.

Reading via ``self.upstream`` is the provenance-safe pattern: the
framework guarantees the restriction matches the current ``key``,
and indexing a non-ancestor table raises ``DataJointError``. See
:doc:`reference/specs/provenance` for the contract.

Raises
------
DataJointError
If accessed outside ``make()`` execution. To construct a trace
explicitly, use ``dj.Diagram.trace(self & key)``.

Examples
--------
::

def make(self, key):
date = self.upstream[Session].fetch1("session_date")
traces = self.upstream[ExtractTraces].to_arrays("trace")
self.insert1({**key, "summary": compute(traces, date)})
"""
if self._upstream is None:
raise DataJointError(
"self.upstream is only available inside make(). "
"Outside make(), construct a trace explicitly: "
"dj.Diagram.trace(self & key)."
)
return self._upstream

class _JobsDescriptor:
"""Descriptor allowing jobs access on both class and instance."""
Expand Down Expand Up @@ -611,6 +651,13 @@ def _populate1(
logger.jobs(f"Making {key} -> {self.full_table_name}")
self.__class__._allow_insert = True

# Pre-construct the upstream view for this make() call. Lazy — only
# `dj.Diagram.trace(self & key)` runs here (graph copy); the
# expensive SQL fetch fires when the user accesses self.upstream[T].
from .diagram import Diagram

self._upstream = Diagram.trace(self & dict(key))

try:
if not is_generator:
make(dict(key), **(make_kwargs or {}))
Expand Down Expand Up @@ -668,6 +715,10 @@ def _populate1(
return True
finally:
self.__class__._allow_insert = False
# Clear the per-make() upstream view so subsequent attribute
# access raises a clear error rather than silently using a
# stale trace from the previous make() call.
self._upstream = None

def progress(self, *restrictions: Any, display: bool = False) -> tuple[int, int]:
"""
Expand Down
239 changes: 239 additions & 0 deletions tests/integration/test_autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,245 @@ def make_insert(self, key, result, scale):
assert row["result"] == 1000 # 200 * 5


# =========================================================================
# #1424: self.upstream pre-restricted ancestor access in make()
# =========================================================================


def test_upstream_provides_pre_restricted_ancestor(prefix, connection_test):
"""make() can read self.upstream[Ancestor] and get pre-restricted data."""
schema = dj.Schema(f"{prefix}_upstream_basic", connection=connection_test)

@schema
class Subject(dj.Lookup):
definition = """
subject_id : int32
---
name : varchar(64)
"""
contents = [(1, "alice"), (2, "bob")]

@schema
class Greeting(dj.Computed):
definition = """
-> Subject
---
greeting : varchar(128)
"""

def make(self, key):
# Provenance-safe read: self.upstream pre-restricted to current key
name = self.upstream[Subject].fetch1("name")
self.insert1({**key, "greeting": f"Hello, {name}!"})

Greeting.populate()
assert (Greeting & {"subject_id": 1}).fetch1("greeting") == "Hello, alice!"
assert (Greeting & {"subject_id": 2}).fetch1("greeting") == "Hello, bob!"


def test_upstream_rejects_non_ancestor(prefix, connection_test):
"""self.upstream[T] for a non-ancestor table raises inside make()."""
schema = dj.Schema(f"{prefix}_upstream_non_ancestor", connection=connection_test)

@schema
class Subject(dj.Lookup):
definition = """
subject_id : int32
"""
contents = [(1,)]

@schema
class Unrelated(dj.Lookup):
definition = """
u_id : int32
"""
contents = [(99,)]

captured_errors: list[Exception] = []

@schema
class Bad(dj.Computed):
definition = """
-> Subject
---
ok : tinyint
"""

def make(self, key):
# class-form lookup (Diagram.__getitem__ class branch)
try:
self.upstream[Unrelated]
except DataJointError as exc:
captured_errors.append(("class", exc))
# string-form lookup (the separate FreeTable/string branch)
try:
self.upstream[Unrelated.full_table_name]
except DataJointError as exc:
captured_errors.append(("string", exc))
# Insert anyway so populate doesn't fail
self.insert1({**key, "ok": 1})

Bad.populate()
# Both the class-form and string-form lookups must reject the non-ancestor.
forms = {form for form, _ in captured_errors}
assert forms == {"class", "string"}, f"expected both branches to raise, got {forms}"
class_err = next(exc for form, exc in captured_errors if form == "class")
assert "not in this trace" in str(class_err).lower()


def test_upstream_unset_outside_make(prefix, connection_test):
"""Accessing self.upstream outside of make() raises a clear error."""
schema = dj.Schema(f"{prefix}_upstream_outside_make", connection=connection_test)

@schema
class Source(dj.Lookup):
definition = """
source_id : int32
"""
contents = [(1,)]

@schema
class Derived(dj.Computed):
definition = """
-> Source
---
val : int32
"""

def make(self, key):
self.insert1({**key, "val": 0})

with pytest.raises(DataJointError, match="only available inside make"):
Derived().upstream


def test_upstream_cleared_after_make(prefix, connection_test):
"""After make() completes, the SAME instance that ran make() has its
self.upstream cleared. Capturing the populate instance is what gives this
test teeth: it would FAIL if the `finally: self._upstream = None` line were
removed (a fresh-instance probe would pass regardless)."""
schema = dj.Schema(f"{prefix}_upstream_cleared", connection=connection_test)

@schema
class Source(dj.Lookup):
definition = """
source_id : int32
"""
contents = [(1,)]

captured = []

@schema
class Derived(dj.Computed):
definition = """
-> Source
---
val : int32
"""

def make(self, key):
captured.append(self) # the actual populate instance
assert self._upstream is not None # set for the duration of make()
self.insert1({**key, "val": 0})

Derived.populate()
assert captured, "make() did not run"
inst = captured[0]
# The finally block must have cleared _upstream on this very instance.
assert inst._upstream is None
with pytest.raises(DataJointError, match="only available inside make"):
inst.upstream


def test_upstream_cleared_after_make_raises(prefix, connection_test):
"""The reset lives in `finally` specifically so it survives an exception in
make(). Force make() to raise and assert the populate instance's
self.upstream is still cleared."""
schema = dj.Schema(f"{prefix}_upstream_exc", connection=connection_test)

@schema
class Source(dj.Lookup):
definition = """
source_id : int32
"""
contents = [(1,)]

captured = []

@schema
class Boom(dj.Computed):
definition = """
-> Source
---
val : int32
"""

def make(self, key):
captured.append(self)
assert self._upstream is not None
raise RuntimeError("make failed on purpose")

with pytest.raises(RuntimeError, match="make failed on purpose"):
Boom.populate(suppress_errors=False)
assert captured, "make() did not run"
inst = captured[0]
# Cleared by the finally block even though make() raised.
assert inst._upstream is None
with pytest.raises(DataJointError, match="only available inside make"):
inst.upstream


def test_upstream_seen_across_tripartite_make(prefix, connection_test):
"""The tripartite make() sees the SAME self.upstream object across all three
phases (fetch / compute / insert) for a given key — constructed once,
shared. Asserted via object identity, not just a correct result."""
schema = dj.Schema(f"{prefix}_upstream_tripartite", connection=connection_test)

@schema
class Source(dj.Lookup):
definition = """
source_id : int32
---
value : int32
"""
contents = [(1, 100), (2, 200)]

seen = [] # (source_id, phase, id(self._upstream))

@schema
class TriComputed(dj.Computed):
definition = """
-> Source
---
result : int32
"""

def make_fetch(self, key):
seen.append((key["source_id"], "fetch", id(self._upstream)))
return (self.upstream[Source].fetch1("value"),)

def make_compute(self, key, value):
seen.append((key["source_id"], "compute", id(self._upstream)))
return (value * 2,)

def make_insert(self, key, doubled):
seen.append((key["source_id"], "insert", id(self._upstream)))
self.insert1({**key, "result": doubled})

TriComputed.populate()
assert (TriComputed & {"source_id": 1}).fetch1("result") == 200
assert (TriComputed & {"source_id": 2}).fetch1("result") == 400

# Every phase that ran for a given key must have observed one and the same
# self.upstream object (not None, not rebuilt per phase).
ids_by_key = {}
for sid, _phase, uid in seen:
ids_by_key.setdefault(sid, set()).add(uid)
assert ids_by_key, "tripartite make did not run"
for sid, ids in ids_by_key.items():
assert len(ids) == 1, f"source_id={sid}: self.upstream differed across phases: {ids}"


def test_populate_reserve_jobs_respects_restrictions(clean_autopopulate, subject, experiment):
"""Regression test for #1413: populate() with reserve_jobs=True must honour restrictions.

Expand Down
Loading