From 8d9d2422b65988157301e865b2674cef62399bb1 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 23 Jun 2026 07:40:14 -0500 Subject: [PATCH 1/3] feat(#1423): Diagram.trace() for upstream restriction propagation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements T2.2.a of the provenance trinity (datajoint-docs#183). Upstream mirror of Diagram.cascade(): walks the FK graph from a restricted seed to every ancestor with OR convergence — an ancestor entity is included if reachable through any FK path from the seed. Reuses the upward propagation primitives added by #1468 (_apply_propagation_rule_upward / _find_real_edge_props) applied here in a generalized form (any child → any parent, not just Part → Master). Branch note: stacked on fix/1429-cascade-part-part-renamed-fk (#1468) for the upward primitives. Will rebase onto master after #1468 lands. What's added: - src/datajoint/diagram.py: - New @classmethod Diagram.trace(table_expr) — mirror of cascade(), walks ancestors instead of descendants, trims to ancestor subgraph. - New _propagate_restrictions_upstream(start_node) — multi-pass walk over in_edges, applies the upward rules at each real edge. Alias-node transparent. - New __getitem__(key) — supports both Table subclass/instance (returns pre-restricted QueryExpression) and string (returns pre-restricted FreeTable). Raises DataJointError for tables outside the trace's subgraph. - Bugfix in _apply_propagation_rule_upward Backward Rule 3: previous code projected child to its OWN PK (child_ft.proj()) which excluded non-primary FK columns. Now projects to the FK columns via proj(*attr_map.keys()), correctly carrying them into the parent restriction for non-primary-FK cases. Caught by test_trace_or_convergence_two_paths. - src/datajoint/dependencies.py: - New load_all_upstream() — symmetric to load_all_downstream. Iteratively discovers upstream schemas reachable via reverse FK edges, expanding the graph until convergence. - src/datajoint/adapters/{base,mysql,postgres}.py: - New find_upstream_schemas_sql(schemas_list) on each adapter, symmetric to find_downstream_schemas_sql. - tests/integration/test_trace.py (new, 8 tests covering single-hop, multi-hop, renamed FK, OR convergence across two paths, non-ancestor rejection, string indexing → FreeTable, counts(), leaf-table seed). All 8 trace tests pass on MySQL. Regression: test_cascade_delete + test_cascading_delete + test_dependencies + test_semantic_matching — 36 tests pass, no regressions from the Rule 3 fix. Slated for DataJoint 2.3. --- src/datajoint/adapters/base.py | 23 +++ src/datajoint/adapters/mysql.py | 10 + src/datajoint/adapters/postgres.py | 14 ++ src/datajoint/dependencies.py | 29 +++ src/datajoint/diagram.py | 238 ++++++++++++++++++++++- tests/integration/test_trace.py | 293 +++++++++++++++++++++++++++++ 6 files changed, 605 insertions(+), 2 deletions(-) create mode 100644 tests/integration/test_trace.py diff --git a/src/datajoint/adapters/base.py b/src/datajoint/adapters/base.py index da4779543..e79a5d4df 100644 --- a/src/datajoint/adapters/base.py +++ b/src/datajoint/adapters/base.py @@ -850,6 +850,29 @@ def find_downstream_schemas_sql(self, schemas_list: str) -> str: raise NotImplementedError ... + def find_upstream_schemas_sql(self, schemas_list: str) -> str: + """ + Generate query to find schemas that the given schemas reference via FK. + + Used to discover unloaded schemas that the loaded ones depend on + (the upstream / ancestor direction). Symmetric to + :meth:`find_downstream_schemas_sql`. + + Parameters + ---------- + schemas_list : str + Comma-separated, quoted schema names for an IN clause. + + Returns + ------- + str + SQL query returning rows with a single column ``schema_name`` + containing distinct schema names that are referenced by the + given schemas. + """ + raise NotImplementedError + ... + @abstractmethod def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str: """ diff --git a/src/datajoint/adapters/mysql.py b/src/datajoint/adapters/mysql.py index f035ba87f..4d2d4ca73 100644 --- a/src/datajoint/adapters/mysql.py +++ b/src/datajoint/adapters/mysql.py @@ -696,6 +696,16 @@ def find_downstream_schemas_sql(self, schemas_list: str) -> str: f"AND table_schema NOT IN ({schemas_list})" ) + def find_upstream_schemas_sql(self, schemas_list: str) -> str: + """Find schemas that the given schemas reference via FK.""" + return ( + f"SELECT DISTINCT referenced_table_schema as schema_name " + f"FROM information_schema.key_column_usage " + f"WHERE table_schema IN ({schemas_list}) " + f"AND referenced_table_schema IS NOT NULL " + f"AND referenced_table_schema NOT IN ({schemas_list})" + ) + def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str: """Query to get FK constraint details from information_schema.""" return ( diff --git a/src/datajoint/adapters/postgres.py b/src/datajoint/adapters/postgres.py index 543e972d3..1dc062bda 100644 --- a/src/datajoint/adapters/postgres.py +++ b/src/datajoint/adapters/postgres.py @@ -861,6 +861,20 @@ def find_downstream_schemas_sql(self, schemas_list: str) -> str: f"AND ns1.nspname NOT IN ({schemas_list})" ) + def find_upstream_schemas_sql(self, schemas_list: str) -> str: + """Find schemas that the given schemas reference via FK.""" + return ( + f"SELECT DISTINCT ns2.nspname as schema_name " + f"FROM pg_constraint c " + f"JOIN pg_class cl1 ON c.conrelid = cl1.oid " + f"JOIN pg_namespace ns1 ON cl1.relnamespace = ns1.oid " + f"JOIN pg_class cl2 ON c.confrelid = cl2.oid " + f"JOIN pg_namespace ns2 ON cl2.relnamespace = ns2.oid " + f"WHERE c.contype = 'f' " + f"AND ns1.nspname IN ({schemas_list}) " + f"AND ns2.nspname NOT IN ({schemas_list})" + ) + def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str: """ Query to get FK constraint details from information_schema. diff --git a/src/datajoint/dependencies.py b/src/datajoint/dependencies.py index 08fb50e1b..9b67c00d0 100644 --- a/src/datajoint/dependencies.py +++ b/src/datajoint/dependencies.py @@ -259,6 +259,35 @@ def load_all_downstream(self) -> None: self.load(force=True, schema_names=known_schemas) + def load_all_upstream(self) -> None: + """ + Load dependencies including all upstream schemas referenced via FK chains. + + Iteratively discovers schemas that the currently loaded schemas + reference, expanding the dependency graph until no new schemas + are found. This ensures that upstream restriction propagation + (``Diagram.trace()``) reaches all ancestor tables, including + those in schemas the user has not explicitly activated. + + Called automatically by ``Diagram.trace()``. Symmetric to + :meth:`load_all_downstream`. + """ + adapter = self._conn.adapter + known_schemas = set(self._conn.schemas) + if not known_schemas: + self.load() + return + + while True: + schemas_list = ", ".join(adapter.quote_string(s) for s in known_schemas) + result = self._conn.query(adapter.find_upstream_schemas_sql(schemas_list)) + new_schemas = {row[0] for row in result} - known_schemas + if not new_schemas: + break + known_schemas |= new_schemas + + self.load(force=True, schema_names=known_schemas) + def topo_sort(self) -> list[str]: """ Return table names in topological order. diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index 9b6c659f3..3b2ac09b0 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -387,6 +387,235 @@ def cascade(cls, table_expr, part_integrity="enforce"): result._expanded_nodes &= keep return result + @classmethod + def trace(cls, table_expr): + """ + Create an upstream-trace diagram for a (restricted) table expression. + + The upstream mirror of :meth:`cascade`. Walks the FK graph upward + from the seed, propagating the restriction to every ancestor with + OR convergence — an ancestor entity is included if reachable through + *any* FK path from the seed. + + Reuses the upward propagation rules + (``_apply_propagation_rule_upward``) defined alongside the cascade + engine, applied here in a generalized form (any child → any parent, + not just Part → Master). + + Parameters + ---------- + table_expr : QueryExpression + A (possibly restricted) table expression. + (e.g., ``Spectrum & key``). + + Returns + ------- + Diagram + New Diagram restricted to the seed and its ancestors, with + per-ancestor restrictions accumulated through the FK graph. + Use ``diagram[T]`` to obtain a pre-restricted + ``QueryExpression`` for an ancestor, or ``diagram.counts()`` + to preview row counts per ancestor. + + Examples + -------- + >>> trace = dj.Diagram.trace(Spectrum & {"recording_id": 5}) + >>> trace[Session].fetch1("session_date") + >>> trace.counts() # entity counts per ancestor + >>> trace["schema.Session"] # FreeTable, when class isn't in scope + + See Also + -------- + :meth:`cascade` — the downstream mirror. + """ + conn = table_expr.connection + conn.dependencies.load_all_upstream() + node = table_expr.full_table_name + + result = cls.__new__(cls) + nx.DiGraph.__init__(result, conn.dependencies) + result._connection = conn + result.context = {} + result._cascade_restrictions = {} # trace uses cascade-shape storage (OR semantics) + result._restrict_conditions = {} + result._restriction_attrs = {} + result._mode = "trace" + + # Include seed + all ancestors + ancestors = set(nx.ancestors(result, node)) | {node} + result.nodes_to_show = ancestors + result._expanded_nodes = set(ancestors) + + # Seed restriction + restriction = AndList(table_expr.restriction) + result._cascade_restrictions[node] = [restriction] if restriction else [] + result._restriction_attrs[node] = set(table_expr.restriction_attributes) + + # Propagate upstream + result._propagate_restrictions_upstream(node) + + # Trim graph to trace subgraph: only restricted tables (seed + ancestors) + # plus alias nodes connecting them. + keep = set(result._cascade_restrictions) + for alias in (n for n in result.nodes() if n.isdigit()): + if set(result.predecessors(alias)) & keep and set(result.successors(alias)) & keep: + keep.add(alias) + result.remove_nodes_from(set(result.nodes()) - keep) + result.nodes_to_show &= keep + result._expanded_nodes &= keep + return result + + def _propagate_restrictions_upstream(self, start_node): + """ + Propagate the seed's restriction upstream through the FK graph. + + Symmetric to :meth:`_propagate_restrictions` but walks ``in_edges`` + instead of ``out_edges`` and applies the upward rules + (``_apply_propagation_rule_upward``) at each real edge. Multiple + passes until no new ancestor is restricted; termination is + guaranteed because the dependency graph is a DAG. + """ + sorted_nodes = topo_sort(self) + # Only propagate through ancestors of start_node + allowed_nodes = {start_node} | set(nx.ancestors(self, start_node)) + propagated_edges = set() + + restrictions = self._cascade_restrictions + + any_new = True + while any_new: + any_new = False + + # Walk in reverse topological order so children are processed + # before their parents — when we reach a parent, its restriction + # accumulates from all of its (already-processed) children. + for node in reversed(sorted_nodes): + if node not in restrictions or node not in allowed_nodes: + continue + + child_ft = self._restricted_table(node) + child_attrs = self._restriction_attrs.get(node, set()) + + for parent, _, edge_props in self.in_edges(node, data=True): + edge_key = (parent, node) + if edge_key in propagated_edges: + continue + propagated_edges.add(edge_key) + + if parent not in allowed_nodes: + continue + + if isinstance(parent, str) and parent.isdigit(): + # Alias node — find the real parent on the far side. + # The alias has its own in_edges; the props on both + # half-edges are identical, so we can use either. + for real_parent, _, real_edge_props in self.in_edges(parent, data=True): + real_edge_key = (real_parent, parent, node) + if real_edge_key in propagated_edges: + continue + propagated_edges.add(real_edge_key) + if real_parent not in allowed_nodes: + continue + attr_map = real_edge_props.get("attr_map", {}) + aliased = real_edge_props.get("aliased", False) + was_new = real_parent not in restrictions + self._apply_propagation_rule_upward( + child_ft, + child_attrs, + real_parent, + attr_map, + aliased, + "cascade", # OR semantics for trace + restrictions, + ) + if was_new and real_parent in restrictions: + any_new = True + else: + attr_map = edge_props.get("attr_map", {}) + aliased = edge_props.get("aliased", False) + was_new = parent not in restrictions + self._apply_propagation_rule_upward( + child_ft, + child_attrs, + parent, + attr_map, + aliased, + "cascade", + restrictions, + ) + if was_new and parent in restrictions: + any_new = True + + def __getitem__(self, key): + """ + Return a pre-restricted query expression (or FreeTable) for an + ancestor table in this trace. + + Parameters + ---------- + key : type or str + A Table subclass (e.g. ``Session``) — returns a pre-restricted + ``QueryExpression``. Or a string giving the table's class name + or fully-qualified SQL name — returns a pre-restricted + ``FreeTable``. + + Returns + ------- + QueryExpression or FreeTable + The ancestor's table restricted to rows reachable via FK from + the seed of this trace. + + Raises + ------ + DataJointError + If the requested table is not in the trace's subgraph (i.e. + not an ancestor of the seed, and not the seed itself). + + Examples + -------- + >>> trace = dj.Diagram.trace(Spectrum & key) + >>> trace[Session].fetch1("session_date") # class index + >>> trace["my_schema.Session"].to_dicts() # string index → FreeTable + """ + from .table import Table + + # Resolve `key` to a full table name + if isinstance(key, type) and issubclass(key, Table): + full_name = key.full_table_name + elif isinstance(key, Table): + full_name = key.full_table_name + elif isinstance(key, str): + # Accept either a class name (resolve via context) or a full SQL name + if "`" in key or '"' in key: + full_name = key + else: + # Class name — search graph nodes for a matching tail + candidates = [ + n + for n in self.nodes() + if not (isinstance(n, str) and n.isdigit()) and n.lower().rstrip('`"').endswith(key.lower()) + ] + if not candidates: + raise DataJointError(f"Table {key!r} is not in this trace's subgraph " f"(not an ancestor of the seed).") + if len(candidates) > 1: + raise DataJointError( + f"Ambiguous table reference {key!r}: matches " f"{', '.join(candidates)}. Use a fully-qualified name." + ) + full_name = candidates[0] + else: + raise DataJointError(f"trace[...] expects a Table class, Table instance, or string; " f"got {type(key).__name__}.") + + if full_name not in self._cascade_restrictions: + raise DataJointError(f"Table {full_name} is not in this trace's subgraph " f"(not an ancestor of the seed).") + + # For class-typed key, return a restricted class instance; for string, + # return a FreeTable. + if isinstance(key, (type, Table)): + ft = self._restricted_table(full_name) + return ft + else: + return self._restricted_table(full_name) + def _restricted_table(self, node): """ Return a FreeTable for ``node`` with this diagram's restrictions applied. @@ -648,8 +877,13 @@ def _apply_propagation_rule_upward(self, child_ft, child_attrs, parent_node, att restrictions.setdefault(parent_node, AndList()).append(parent_item) parent_attrs = set(attr_map.values()) # parent's PK column names else: - # Backward Rule 3: project child to parent PK - parent_item = child_ft.proj() + # Backward Rule 3: project child to its FK columns (which by name + # match parent's PK columns in the non-aliased case). For primary + # FKs (attr_map.keys() ⊆ child_pk) this is a no-op since + # ``proj()`` already returns the PK. For non-primary FKs this + # explicitly carries the FK columns into the projection so the + # subsequent restriction on the parent joins on the right columns. + parent_item = child_ft.proj(*attr_map.keys()) if mode == "cascade": restrictions.setdefault(parent_node, []).append(parent_item) else: diff --git a/tests/integration/test_trace.py b/tests/integration/test_trace.py new file mode 100644 index 000000000..787635bfe --- /dev/null +++ b/tests/integration/test_trace.py @@ -0,0 +1,293 @@ +""" +Integration tests for ``Diagram.trace()`` — upstream restriction propagation. + +The upstream mirror of ``Diagram.cascade()``. Walks the FK graph from a +restricted seed to every ancestor with OR convergence. Reuses the upward +propagation rules (U1/U2/U3 in cascade.md) added by #1468. +""" + +import pytest + +import datajoint as dj +from datajoint.errors import DataJointError + + +@pytest.fixture(scope="function") +def schema_by_backend(connection_by_backend, db_creds_by_backend, request): + """Create a fresh schema for each trace test.""" + backend = db_creds_by_backend["backend"] + import time + + test_id = str(int(time.time() * 1000))[-8:] + schema_name = f"djtest_trace_{backend}_{test_id}"[:64] + + if connection_by_backend.is_connected: + try: + connection_by_backend.query( + f"DROP DATABASE IF EXISTS {connection_by_backend.adapter.quote_identifier(schema_name)}" + ) + except Exception: + pass + + schema = dj.Schema(schema_name, connection=connection_by_backend) + yield schema + + if connection_by_backend.is_connected: + try: + connection_by_backend.query( + f"DROP DATABASE IF EXISTS {connection_by_backend.adapter.quote_identifier(schema_name)}" + ) + except Exception: + pass + + +def test_trace_single_hop(schema_by_backend): + """trace(Child & key)[Parent] returns Parent restricted via the FK.""" + + @schema_by_backend + class Parent(dj.Manual): + definition = """ + parent_id : int32 + --- + name : varchar(64) + """ + + @schema_by_backend + class Child(dj.Manual): + definition = """ + -> Parent + child_id : int32 + """ + + Parent.insert([(1, "alice"), (2, "bob")]) + Child.insert([(1, 10), (1, 11), (2, 20)]) + + trace = dj.Diagram.trace(Child & {"parent_id": 1, "child_id": 10}) + + # Seed itself + assert len(trace[Child]) == 1 + + # Ancestor: Parent restricted to the rows that contributed to the seed + assert len(trace[Parent]) == 1 + assert trace[Parent].fetch1("parent_id") == 1 + + +def test_trace_multi_hop(schema_by_backend): + """trace walks through intermediate ancestors (Grandparent ← Parent ← Child).""" + + @schema_by_backend + class Grandparent(dj.Manual): + definition = """ + gp_id : int32 + """ + + @schema_by_backend + class Parent(dj.Manual): + definition = """ + -> Grandparent + parent_id : int32 + """ + + @schema_by_backend + class Child(dj.Manual): + definition = """ + -> Parent + child_id : int32 + """ + + Grandparent.insert([(1,), (2,)]) + Parent.insert([(1, 10), (1, 11), (2, 20)]) + Child.insert([(1, 10, 100), (1, 11, 110), (2, 20, 200)]) + + trace = dj.Diagram.trace(Child & {"gp_id": 1, "parent_id": 10, "child_id": 100}) + + # All three ancestors restricted to the one contributing tuple per level + assert len(trace[Child]) == 1 + assert len(trace[Parent]) == 1 + assert len(trace[Grandparent]) == 1 + assert trace[Grandparent].fetch1("gp_id") == 1 + + +def test_trace_renamed_fk(schema_by_backend): + """Renamed FK (.proj(...)) — the upward rule reverses the rename.""" + + @schema_by_backend + class Animal(dj.Manual): + definition = """ + animal_id : int32 + --- + species : varchar(64) + """ + + @schema_by_backend + class Observation(dj.Manual): + definition = """ + obs_id : int32 + --- + -> Animal.proj(subject_id='animal_id') + measurement : float64 + """ + + Animal.insert([(1, "Mouse"), (2, "Rat")]) + Observation.insert([(10, 1, 1.5), (11, 1, 2.5), (20, 2, 3.0)]) + + # Observation columns: obs_id, subject_id (renamed), measurement. + # No `animal_id` column on Observation — the upward walk must reverse the rename. + trace = dj.Diagram.trace(Observation & {"obs_id": 10}) + + assert len(trace[Animal]) == 1 + assert trace[Animal].fetch1("animal_id") == 1 + assert trace[Animal].fetch1("species") == "Mouse" + + +def test_trace_or_convergence_two_paths(schema_by_backend): + """Two FK paths from child to the same ancestor → OR (union) at the ancestor.""" + + @schema_by_backend + class Source(dj.Manual): + definition = """ + source_id : int32 + """ + + @schema_by_backend + class Downstream(dj.Manual): + definition = """ + downstream_id : int32 + --- + -> Source + -> Source.proj(comparison_src='source_id') + """ + + Source.insert([(1,), (2,), (3,)]) + # Downstream rows reference Source via two columns; OR convergence means the + # ancestor is restricted to the UNION of contributors across both FK paths. + Downstream.insert( + [ + (100, 1, 2), # primary source=1, comparison_src=2 + (101, 3, 3), # primary source=3, comparison_src=3 + ] + ) + + trace = dj.Diagram.trace(Downstream & {"downstream_id": 100}) + + # Source is restricted via BOTH FK paths from row 100 → {1, 2} + contributing = set(trace[Source].fetch("source_id")) + assert contributing == {1, 2} + + +def test_trace_rejects_non_ancestor(schema_by_backend): + """Indexing into a table that isn't in the trace's subgraph raises.""" + + @schema_by_backend + class A(dj.Manual): + definition = """ + a_id : int32 + """ + + @schema_by_backend + class B(dj.Manual): + definition = """ + b_id : int32 + """ + + @schema_by_backend + class C(dj.Manual): + definition = """ + -> A + c_id : int32 + """ + + A.insert([(1,)]) + B.insert([(99,)]) + C.insert([(1, 10)]) + + trace = dj.Diagram.trace(C & {"a_id": 1, "c_id": 10}) + + # A is an ancestor — OK + assert len(trace[A]) == 1 + + # B is unrelated — should raise + with pytest.raises(DataJointError, match="not in this trace"): + trace[B] + + +def test_trace_string_indexing_returns_freetable(schema_by_backend): + """trace[str] returns a FreeTable (no class needed in caller scope).""" + from datajoint.table import FreeTable + + @schema_by_backend + class Parent(dj.Manual): + definition = """ + parent_id : int32 + """ + + @schema_by_backend + class Child(dj.Manual): + definition = """ + -> Parent + child_id : int32 + """ + + Parent.insert([(1,), (2,)]) + Child.insert([(1, 10), (2, 20)]) + + trace = dj.Diagram.trace(Child & {"parent_id": 1, "child_id": 10}) + + # String accepts the SQL-quoted full name + parent_via_string = trace[Parent.full_table_name] + assert isinstance(parent_via_string, FreeTable) + assert len(parent_via_string) == 1 + + +def test_trace_counts(schema_by_backend): + """trace.counts() reports per-ancestor row counts under the seed's restriction.""" + + @schema_by_backend + class Grandparent(dj.Manual): + definition = """ + gp_id : int32 + """ + + @schema_by_backend + class Parent(dj.Manual): + definition = """ + -> Grandparent + parent_id : int32 + """ + + @schema_by_backend + class Child(dj.Manual): + definition = """ + -> Parent + child_id : int32 + """ + + Grandparent.insert([(1,), (2,)]) + Parent.insert([(1, 10), (1, 11), (2, 20)]) + Child.insert([(1, 10, 100), (1, 11, 110), (2, 20, 200)]) + + trace = dj.Diagram.trace(Child & {"gp_id": 1}) + counts = trace.counts() + + assert counts[Grandparent.full_table_name] == 1 + assert counts[Parent.full_table_name] == 2 + assert counts[Child.full_table_name] == 2 + + +def test_trace_seed_with_no_ancestors(schema_by_backend): + """Tracing from a table with no FK parents → trace contains only the seed.""" + + @schema_by_backend + class Standalone(dj.Manual): + definition = """ + std_id : int32 + """ + + Standalone.insert([(1,), (2,)]) + + trace = dj.Diagram.trace(Standalone & {"std_id": 1}) + + # Only the seed is in the trace + assert len(trace[Standalone]) == 1 + counts = trace.counts() + assert counts == {Standalone.full_table_name: 1} From 03dd0a7b8e993c8a6fb46ba56934c0cdfe4dd47c Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 23 Jun 2026 08:19:58 -0500 Subject: [PATCH 2/3] fix(#1423): gate Diagram.__getitem__ on _mode == "trace" The trace-mode __getitem__ I added shadowed networkx.DiGraph's standard adjacency-dict lookup for ALL Diagrams, not just trace results. ERD tests (and any other code that does diagram[node_name] for adjacency) were getting DataJointError("not in this trace's subgraph") instead of the adjacency dict. Fix: short-circuit non-trace diagrams (no _mode attribute or _mode != "trace") to super().__getitem__(key) before any trace-specific logic runs. Tests: - 5 previously-failing erd tests now pass (test_erd, test_diagram_algebra, test_repr_svg, test_make_image, test_part_table_parsing). - 8/8 trace tests still pass. --- src/datajoint/diagram.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index 3b2ac09b0..b2572cfaf 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -551,6 +551,10 @@ def __getitem__(self, key): Return a pre-restricted query expression (or FreeTable) for an ancestor table in this trace. + Only meaningful for trace diagrams (constructed via + :meth:`Diagram.trace`). For ordinary diagrams, defers to + :class:`networkx.DiGraph`'s adjacency-dict lookup. + Parameters ---------- key : type or str @@ -577,6 +581,12 @@ def __getitem__(self, key): >>> trace[Session].fetch1("session_date") # class index >>> trace["my_schema.Session"].to_dicts() # string index → FreeTable """ + # Non-trace diagrams: defer to networkx adjacency lookup so existing + # `diagram[node_name]` patterns (used in diagram algebra, ERD tests) + # keep working. + if getattr(self, "_mode", None) != "trace": + return super().__getitem__(key) + from .table import Table # Resolve `key` to a full table name From 2d81534abbdae86a77ee2fab67aac31de5e21ff0 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Thu, 2 Jul 2026 10:34:06 -0500 Subject: [PATCH 3/3] test(#1423): add diamond OR + cross-schema trace tests; fix Rule 3 docstring Addresses the before-merge coverage @ttngu207 and @MilagrosMarin flagged on #1471 (the two they singled out as worth adding before merge): - test_trace_multi_hop_diamond_or_convergence: an ancestor reached via two MULTI-HOP arms (Leaf -> {Left, Right} -> Root), asserting the OR-union {1, 2}. Guards the multi-pass accumulation in the reverse-topo walk, where a regression would silently drop an arm and yield a subset. - test_trace_cross_schema_ancestor: seed and ancestor in different schemas, exercising load_all_upstream's unloaded-ancestor-schema discovery. Runs on both MySQL and PostgreSQL via schema_by_backend. Also fixes the Backward Rule 3 docstring drift (diagram.py): it described child.proj() but the code projects child.proj(*attr_map.keys()) to carry the FK columns. Upward Part-of-Part chains and an isolated non-aliased secondary-FK test remain as noted follow-ups. --- src/datajoint/diagram.py | 4 +- tests/integration/test_trace.py | 98 +++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+), 1 deletion(-) diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index b2572cfaf..e00d0328c 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -863,7 +863,9 @@ def _apply_propagation_rule_upward(self, child_ft, child_attrs, parent_node, att ``child.proj(**{parent: child for child, parent in attr_map.items()})`` — reverses the renaming so the result has parent's column names. 3. Non-aliased AND child restriction attrs ⊄ parent PK: - ``child.proj()`` — project child to parent's PK columns. + ``child.proj(*attr_map.keys())`` — project child onto its FK columns + (which, being non-aliased, share names with parent's PK columns) so + the subsequent restriction on the parent joins on the right columns. """ parent_pk = self.nodes[parent_node].get("primary_key", set()) diff --git a/tests/integration/test_trace.py b/tests/integration/test_trace.py index 787635bfe..06b948396 100644 --- a/tests/integration/test_trace.py +++ b/tests/integration/test_trace.py @@ -274,6 +274,104 @@ class Child(dj.Manual): assert counts[Child.full_table_name] == 2 +def test_trace_multi_hop_diamond_or_convergence(schema_by_backend): + """Diamond: an ancestor reached via two MULTI-HOP paths → OR-union across + both arms. Unlike test_trace_or_convergence_two_paths (adjacent two-edge + case), this forces the reverse-topo walk to accumulate a contributor for the + same ancestor across separate multi-pass arms. A regression that dropped an + OR arm would yield a subset here.""" + + @schema_by_backend + class Root(dj.Manual): + definition = """ + root_id : int32 + """ + + @schema_by_backend + class Left(dj.Manual): + definition = """ + -> Root + left_id : int32 + """ + + @schema_by_backend + class Right(dj.Manual): + # renamed FK avoids the root_id name collision when Leaf reconverges + definition = """ + -> Root.proj(root_id2='root_id') + right_id : int32 + """ + + @schema_by_backend + class Leaf(dj.Manual): + definition = """ + -> Left + -> Right + leaf_id : int32 + """ + + Root.insert([(1,), (2,), (3,)]) + Left.insert([(1, 10)]) # Left row → Root 1 + Right.insert([(2, 20)]) # Right row (root_id2=2) → Root 2 + # Leaf PK order: root_id, left_id, root_id2, right_id, leaf_id + Leaf.insert([(1, 10, 2, 20, 100)]) + + trace = dj.Diagram.trace(Leaf & {"leaf_id": 100}) + + # Root reached via Leaf→Left→Root (root_id=1) OR Leaf→Right→Root + # (root_id2=2 reversed to root_id=2). Union = {1, 2}; Root 3 excluded. + contributing = set(trace[Root].fetch("root_id")) + assert contributing == {1, 2} + + +def test_trace_cross_schema_ancestor(schema_by_backend, connection_by_backend): + """Ancestor in a DIFFERENT schema than the seed → load_all_upstream must + discover the unloaded ancestor schema via reverse FK-schema lookup.""" + import time + + backend = connection_by_backend.adapter + other_name = f"djtest_trace_other_{str(int(time.time() * 1000))[-8:]}"[:64] + if connection_by_backend.is_connected: + try: + connection_by_backend.query(f"DROP DATABASE IF EXISTS {backend.quote_identifier(other_name)}") + except Exception: + pass + other = dj.Schema(other_name, connection=connection_by_backend) + + try: + + @schema_by_backend + class Upstream(dj.Manual): + definition = """ + up_id : int32 + --- + label : varchar(32) + """ + + @other + class Downstream(dj.Manual): + # cross-schema FK: Downstream lives in `other`, Upstream in schema_by_backend + definition = """ + -> Upstream + down_id : int32 + """ + + Upstream.insert([(1, "a"), (2, "b")]) + Downstream.insert([(1, 10), (2, 20)]) + + trace = dj.Diagram.trace(Downstream & {"up_id": 1, "down_id": 10}) + + assert len(trace[Upstream]) == 1 + assert trace[Upstream].fetch1("up_id") == 1 + assert trace[Upstream].fetch1("label") == "a" + finally: + if connection_by_backend.is_connected: + try: + connection_by_backend.query(f"DROP DATABASE IF EXISTS {backend.quote_identifier(other_name)}") + except Exception: + pass + + def test_trace_seed_with_no_ancestors(schema_by_backend): """Tracing from a table with no FK parents → trace contains only the seed."""