Skip to content

Commit 80489fc

Browse files
feat: add collapse() method for high-level pipeline views
- Add collapse() method to mark diagrams for collapsing when combined - Collapsed schemas appear as single nodes showing table count - "Expanded wins" - nodes in non-collapsed diagrams stay expanded - Works with both Graphviz and Mermaid output - Use box3d shape for collapsed nodes in Graphviz Example: dj.Diagram(schema1) + dj.Diagram(schema2).collapse() Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent d41b75f commit 80489fc

File tree

1 file changed

+245
-34
lines changed

1 file changed

+245
-34
lines changed

src/datajoint/diagram.py

Lines changed: 245 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ def __init__(self, source, context=None) -> None:
103103
if isinstance(source, Diagram):
104104
# copy constructor
105105
self.nodes_to_show = set(source.nodes_to_show)
106+
self._explicit_nodes = set(source._explicit_nodes)
107+
self._is_collapsed = source._is_collapsed
106108
self.context = source.context
107109
super().__init__(source)
108110
return
@@ -130,6 +132,8 @@ def __init__(self, source, context=None) -> None:
130132

131133
# Enumerate nodes from all the items in the list
132134
self.nodes_to_show = set()
135+
self._explicit_nodes = set() # nodes that should never be collapsed
136+
self._is_collapsed = False # whether this diagram's nodes should be collapsed when combined
133137
try:
134138
self.nodes_to_show.add(source.full_table_name)
135139
except AttributeError:
@@ -181,6 +185,31 @@ def is_part(part, master):
181185
self.nodes_to_show.update(n for n in self.nodes() if any(is_part(n, m) for m in self.nodes_to_show))
182186
return self
183187

188+
def collapse(self) -> "Diagram":
189+
"""
190+
Mark this diagram for collapsing when combined with other diagrams.
191+
192+
When a collapsed diagram is added to a non-collapsed diagram, its nodes
193+
are shown as a single collapsed node per schema, unless they also appear
194+
in the non-collapsed diagram (expanded wins).
195+
196+
Returns
197+
-------
198+
Diagram
199+
A copy of this diagram marked for collapsing.
200+
201+
Examples
202+
--------
203+
>>> # Show schema1 expanded, schema2 collapsed into single nodes
204+
>>> dj.Diagram(schema1) + dj.Diagram(schema2).collapse()
205+
206+
>>> # Explicitly expand one table from schema2
207+
>>> dj.Diagram(schema1) + dj.Diagram(TableFromSchema2) + dj.Diagram(schema2).collapse()
208+
"""
209+
result = Diagram(self)
210+
result._is_collapsed = True
211+
return result
212+
184213
def __add__(self, arg) -> "Diagram":
185214
"""
186215
Union or downstream expansion.
@@ -195,21 +224,36 @@ def __add__(self, arg) -> "Diagram":
195224
Diagram
196225
Combined or expanded diagram.
197226
"""
198-
self = Diagram(self) # copy
227+
result = Diagram(self) # copy
199228
try:
200-
self.nodes_to_show.update(arg.nodes_to_show)
229+
result.nodes_to_show.update(arg.nodes_to_show)
230+
# Handle collapse: nodes from non-collapsed diagrams are explicit (expanded)
231+
if not self._is_collapsed:
232+
result._explicit_nodes.update(self.nodes_to_show)
233+
else:
234+
result._explicit_nodes.update(self._explicit_nodes)
235+
if not arg._is_collapsed:
236+
result._explicit_nodes.update(arg.nodes_to_show)
237+
else:
238+
result._explicit_nodes.update(arg._explicit_nodes)
239+
# Result is not collapsed (it's a combination)
240+
result._is_collapsed = False
201241
except AttributeError:
202242
try:
203-
self.nodes_to_show.add(arg.full_table_name)
243+
result.nodes_to_show.add(arg.full_table_name)
244+
result._explicit_nodes.add(arg.full_table_name)
204245
except AttributeError:
205246
for i in range(arg):
206-
new = nx.algorithms.boundary.node_boundary(self, self.nodes_to_show)
247+
new = nx.algorithms.boundary.node_boundary(result, result.nodes_to_show)
207248
if not new:
208249
break
209250
# add nodes referenced by aliased nodes
210-
new.update(nx.algorithms.boundary.node_boundary(self, (a for a in new if a.isdigit())))
211-
self.nodes_to_show.update(new)
212-
return self
251+
new.update(nx.algorithms.boundary.node_boundary(result, (a for a in new if a.isdigit())))
252+
result.nodes_to_show.update(new)
253+
# Expanded nodes from + N expansion are explicit
254+
if not self._is_collapsed:
255+
result._explicit_nodes = result.nodes_to_show.copy()
256+
return result
213257

214258
def __sub__(self, arg) -> "Diagram":
215259
"""
@@ -305,6 +349,131 @@ def _make_graph(self) -> nx.DiGraph:
305349
nx.relabel_nodes(graph, mapping, copy=False)
306350
return graph
307351

352+
def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str]]:
353+
"""
354+
Apply collapse logic to the graph.
355+
356+
Nodes in nodes_to_show but not in _explicit_nodes are collapsed into
357+
single schema nodes.
358+
359+
Parameters
360+
----------
361+
graph : nx.DiGraph
362+
The graph from _make_graph().
363+
364+
Returns
365+
-------
366+
tuple[nx.DiGraph, dict[str, str]]
367+
Modified graph and mapping of collapsed schema labels to their table count.
368+
"""
369+
if not self._explicit_nodes or self._explicit_nodes == self.nodes_to_show:
370+
# No collapse needed
371+
return graph, {}
372+
373+
# Map full_table_names to class_names
374+
full_to_class = {
375+
node: lookup_class_name(node, self.context) or node
376+
for node in self.nodes_to_show
377+
}
378+
class_to_full = {v: k for k, v in full_to_class.items()}
379+
380+
# Identify explicit class names (should be expanded)
381+
explicit_class_names = {
382+
full_to_class.get(node, node) for node in self._explicit_nodes
383+
}
384+
385+
# Identify nodes to collapse (class names)
386+
nodes_to_collapse = set(graph.nodes()) - explicit_class_names
387+
388+
if not nodes_to_collapse:
389+
return graph, {}
390+
391+
# Group collapsed nodes by schema
392+
collapsed_by_schema = {} # schema_name -> list of class_names
393+
for class_name in nodes_to_collapse:
394+
full_name = class_to_full.get(class_name)
395+
if full_name:
396+
parts = full_name.replace('"', '`').split('`')
397+
if len(parts) >= 2:
398+
schema_name = parts[1]
399+
if schema_name not in collapsed_by_schema:
400+
collapsed_by_schema[schema_name] = []
401+
collapsed_by_schema[schema_name].append(class_name)
402+
403+
if not collapsed_by_schema:
404+
return graph, {}
405+
406+
# Determine labels for collapsed schemas
407+
schema_modules = {}
408+
for schema_name, class_names in collapsed_by_schema.items():
409+
schema_modules[schema_name] = set()
410+
for class_name in class_names:
411+
cls = self._resolve_class(class_name)
412+
if cls is not None and hasattr(cls, "__module__"):
413+
module_name = cls.__module__.split(".")[-1]
414+
schema_modules[schema_name].add(module_name)
415+
416+
collapsed_labels = {} # schema_name -> label
417+
collapsed_counts = {} # label -> count of tables
418+
for schema_name, modules in schema_modules.items():
419+
if len(modules) == 1:
420+
label = next(iter(modules))
421+
else:
422+
label = schema_name
423+
collapsed_labels[schema_name] = label
424+
collapsed_counts[label] = len(collapsed_by_schema[schema_name])
425+
426+
# Create new graph with collapsed nodes
427+
new_graph = nx.DiGraph()
428+
429+
# Map old node names to new names (collapsed nodes -> schema label)
430+
node_mapping = {}
431+
for node in graph.nodes():
432+
full_name = class_to_full.get(node)
433+
if full_name:
434+
parts = full_name.replace('"', '`').split('`')
435+
if len(parts) >= 2 and node in nodes_to_collapse:
436+
schema_name = parts[1]
437+
node_mapping[node] = collapsed_labels[schema_name]
438+
else:
439+
node_mapping[node] = node
440+
else:
441+
# Alias nodes - check if they should be collapsed
442+
# An alias node should be collapsed if ALL its neighbors are collapsed
443+
neighbors = set(graph.predecessors(node)) | set(graph.successors(node))
444+
if neighbors and neighbors <= nodes_to_collapse:
445+
# Get schema from first neighbor
446+
neighbor = next(iter(neighbors))
447+
full_name = class_to_full.get(neighbor)
448+
if full_name:
449+
parts = full_name.replace('"', '`').split('`')
450+
if len(parts) >= 2:
451+
schema_name = parts[1]
452+
node_mapping[node] = collapsed_labels[schema_name]
453+
continue
454+
node_mapping[node] = node
455+
456+
# Add nodes
457+
added_collapsed = set()
458+
for old_node, new_node in node_mapping.items():
459+
if new_node in collapsed_counts:
460+
# This is a collapsed schema node
461+
if new_node not in added_collapsed:
462+
new_graph.add_node(new_node, node_type=None, collapsed=True,
463+
table_count=collapsed_counts[new_node])
464+
added_collapsed.add(new_node)
465+
else:
466+
new_graph.add_node(new_node, **graph.nodes[old_node])
467+
468+
# Add edges (avoiding self-loops and duplicates)
469+
for src, dest, data in graph.edges(data=True):
470+
new_src = node_mapping[src]
471+
new_dest = node_mapping[dest]
472+
if new_src != new_dest and not new_graph.has_edge(new_src, new_dest):
473+
new_graph.add_edge(new_src, new_dest, **data)
474+
475+
return new_graph, collapsed_counts
476+
308477
def _resolve_class(self, name: str):
309478
"""
310479
Safely resolve a table class from a dotted name without eval().
@@ -379,6 +548,9 @@ def make_dot(self):
379548
direction = config.display.diagram_direction
380549
graph = self._make_graph()
381550

551+
# Apply collapse logic if needed
552+
graph, collapsed_counts = self._apply_collapse(graph)
553+
382554
# Build schema mapping: class_name -> schema_name
383555
# Group by database schema, label with Python module name if 1:1 mapping
384556
schema_map = {} # class_name -> schema_name
@@ -474,8 +646,22 @@ def make_dot(self):
474646
size=0.1 * scale,
475647
fixed=False,
476648
),
649+
"collapsed": dict(
650+
shape="box3d",
651+
color="#80808060",
652+
fontcolor="#404040",
653+
fontsize=round(scale * 10),
654+
size=0.5 * scale,
655+
fixed=False,
656+
),
477657
}
478-
node_props = {node: label_props[d["node_type"]] for node, d in dict(graph.nodes(data=True)).items()}
658+
# Build node_props, handling collapsed nodes specially
659+
node_props = {}
660+
for node, d in graph.nodes(data=True):
661+
if d.get("collapsed"):
662+
node_props[node] = label_props["collapsed"]
663+
else:
664+
node_props[node] = label_props[d["node_type"]]
479665

480666
self._encapsulate_node_names(graph)
481667
self._encapsulate_edge_attributes(graph)
@@ -492,23 +678,32 @@ def make_dot(self):
492678
node.set_fixedsize("shape" if props["fixed"] else False)
493679
node.set_width(props["size"])
494680
node.set_height(props["size"])
495-
cls = self._resolve_class(name)
496-
if cls is not None:
497-
description = cls().describe(context=self.context).split("\n")
498-
description = (
499-
("-" * 30 if q.startswith("---") else (q.replace("->", "&#8594;") if "->" in q else q.split(":")[0]))
500-
for q in description
501-
if not q.startswith("#")
502-
)
503-
node.set_tooltip("&#13;".join(description))
504-
# Strip module prefix from label if it matches the cluster label
505-
display_name = name
506-
schema_name = schema_map.get(name)
507-
if schema_name and "." in name:
508-
prefix = name.rsplit(".", 1)[0]
509-
if prefix == cluster_labels.get(schema_name):
510-
display_name = name.rsplit(".", 1)[1]
511-
node.set_label("<<u>" + display_name + "</u>>" if node.get("distinguished") == "True" else display_name)
681+
682+
# Handle collapsed nodes specially
683+
node_data = graph.nodes.get(f'"{name}"', {})
684+
if node_data.get("collapsed"):
685+
table_count = node_data.get("table_count", 0)
686+
label = f"{name}\\n({table_count} tables)" if table_count != 1 else f"{name}\\n(1 table)"
687+
node.set_label(label)
688+
node.set_tooltip(f"Collapsed schema: {table_count} tables")
689+
else:
690+
cls = self._resolve_class(name)
691+
if cls is not None:
692+
description = cls().describe(context=self.context).split("\n")
693+
description = (
694+
("-" * 30 if q.startswith("---") else (q.replace("->", "&#8594;") if "->" in q else q.split(":")[0]))
695+
for q in description
696+
if not q.startswith("#")
697+
)
698+
node.set_tooltip("&#13;".join(description))
699+
# Strip module prefix from label if it matches the cluster label
700+
display_name = name
701+
schema_name = schema_map.get(name)
702+
if schema_name and "." in name:
703+
prefix = name.rsplit(".", 1)[0]
704+
if prefix == cluster_labels.get(schema_name):
705+
display_name = name.rsplit(".", 1)[1]
706+
node.set_label("<<u>" + display_name + "</u>>" if node.get("distinguished") == "True" else display_name)
512707
node.set_color(props["color"])
513708
node.set_style("filled")
514709

@@ -520,11 +715,12 @@ def make_dot(self):
520715
if props is None:
521716
raise DataJointError("Could not find edge with source '{}' and destination '{}'".format(src, dest))
522717
edge.set_color("#00000040")
523-
edge.set_style("solid" if props["primary"] else "dashed")
524-
master_part = graph.nodes[dest]["node_type"] is Part and dest.startswith(src + ".")
718+
edge.set_style("solid" if props.get("primary") else "dashed")
719+
dest_node_type = graph.nodes[dest].get("node_type")
720+
master_part = dest_node_type is Part and dest.startswith(src + ".")
525721
edge.set_weight(3 if master_part else 1)
526722
edge.set_arrowhead("none")
527-
edge.set_penwidth(0.75 if props["multi"] else 2)
723+
edge.set_penwidth(0.75 if props.get("multi") else 2)
528724

529725
# Group nodes into schema clusters (always on)
530726
if schema_map:
@@ -604,6 +800,9 @@ def make_mermaid(self) -> str:
604800
graph = self._make_graph()
605801
direction = config.display.diagram_direction
606802

803+
# Apply collapse logic if needed
804+
graph, collapsed_counts = self._apply_collapse(graph)
805+
607806
# Build schema mapping for grouping
608807
schema_map = {} # class_name -> schema_name
609808
schema_modules = {} # schema_name -> set of module names
@@ -646,6 +845,7 @@ def make_mermaid(self) -> str:
646845
lines.append(" classDef computed fill:#FFB6C1,stroke:#8B0000")
647846
lines.append(" classDef imported fill:#ADD8E6,stroke:#00008B")
648847
lines.append(" classDef part fill:#FFFFFF,stroke:#000000")
848+
lines.append(" classDef collapsed fill:#808080,stroke:#404040")
649849
lines.append("")
650850

651851
# Shape mapping: Manual=box, Computed/Imported=stadium, Lookup/Part=box
@@ -669,14 +869,25 @@ def make_mermaid(self) -> str:
669869
None: "",
670870
}
671871

672-
# Group nodes by schema into subgraphs
872+
# Group nodes by schema into subgraphs (only non-collapsed nodes)
673873
schemas = {}
874+
collapsed_nodes = []
674875
for node, data in graph.nodes(data=True):
675-
schema_name = schema_map.get(node)
676-
if schema_name:
677-
if schema_name not in schemas:
678-
schemas[schema_name] = []
679-
schemas[schema_name].append((node, data))
876+
if data.get("collapsed"):
877+
collapsed_nodes.append((node, data))
878+
else:
879+
schema_name = schema_map.get(node)
880+
if schema_name:
881+
if schema_name not in schemas:
882+
schemas[schema_name] = []
883+
schemas[schema_name].append((node, data))
884+
885+
# Add collapsed nodes (not in subgraphs)
886+
for node, data in collapsed_nodes:
887+
safe_id = node.replace(".", "_").replace(" ", "_")
888+
table_count = data.get("table_count", 0)
889+
count_text = f"{table_count} tables" if table_count != 1 else "1 table"
890+
lines.append(f" {safe_id}[[\"{node}<br/>({count_text})\"]]:::collapsed")
680891

681892
# Add nodes grouped by schema subgraphs
682893
for schema_name, nodes in schemas.items():

0 commit comments

Comments
 (0)