@@ -103,6 +103,8 @@ def __init__(self, source, context=None) -> None:
103103 if isinstance (source , Diagram ):
104104 # copy constructor
105105 self .nodes_to_show = set (source .nodes_to_show )
106+ self ._explicit_nodes = set (source ._explicit_nodes )
107+ self ._is_collapsed = source ._is_collapsed
106108 self .context = source .context
107109 super ().__init__ (source )
108110 return
@@ -130,6 +132,8 @@ def __init__(self, source, context=None) -> None:
130132
131133 # Enumerate nodes from all the items in the list
132134 self .nodes_to_show = set ()
135+ self ._explicit_nodes = set () # nodes that should never be collapsed
136+ self ._is_collapsed = False # whether this diagram's nodes should be collapsed when combined
133137 try :
134138 self .nodes_to_show .add (source .full_table_name )
135139 except AttributeError :
@@ -181,6 +185,31 @@ def is_part(part, master):
181185 self .nodes_to_show .update (n for n in self .nodes () if any (is_part (n , m ) for m in self .nodes_to_show ))
182186 return self
183187
188+ def collapse (self ) -> "Diagram" :
189+ """
190+ Mark this diagram for collapsing when combined with other diagrams.
191+
192+ When a collapsed diagram is added to a non-collapsed diagram, its nodes
193+ are shown as a single collapsed node per schema, unless they also appear
194+ in the non-collapsed diagram (expanded wins).
195+
196+ Returns
197+ -------
198+ Diagram
199+ A copy of this diagram marked for collapsing.
200+
201+ Examples
202+ --------
203+ >>> # Show schema1 expanded, schema2 collapsed into single nodes
204+ >>> dj.Diagram(schema1) + dj.Diagram(schema2).collapse()
205+
206+ >>> # Explicitly expand one table from schema2
207+ >>> dj.Diagram(schema1) + dj.Diagram(TableFromSchema2) + dj.Diagram(schema2).collapse()
208+ """
209+ result = Diagram (self )
210+ result ._is_collapsed = True
211+ return result
212+
184213 def __add__ (self , arg ) -> "Diagram" :
185214 """
186215 Union or downstream expansion.
@@ -195,21 +224,36 @@ def __add__(self, arg) -> "Diagram":
195224 Diagram
196225 Combined or expanded diagram.
197226 """
198- self = Diagram (self ) # copy
227+ result = Diagram (self ) # copy
199228 try :
200- self .nodes_to_show .update (arg .nodes_to_show )
229+ result .nodes_to_show .update (arg .nodes_to_show )
230+ # Handle collapse: nodes from non-collapsed diagrams are explicit (expanded)
231+ if not self ._is_collapsed :
232+ result ._explicit_nodes .update (self .nodes_to_show )
233+ else :
234+ result ._explicit_nodes .update (self ._explicit_nodes )
235+ if not arg ._is_collapsed :
236+ result ._explicit_nodes .update (arg .nodes_to_show )
237+ else :
238+ result ._explicit_nodes .update (arg ._explicit_nodes )
239+ # Result is not collapsed (it's a combination)
240+ result ._is_collapsed = False
201241 except AttributeError :
202242 try :
203- self .nodes_to_show .add (arg .full_table_name )
243+ result .nodes_to_show .add (arg .full_table_name )
244+ result ._explicit_nodes .add (arg .full_table_name )
204245 except AttributeError :
205246 for i in range (arg ):
206- new = nx .algorithms .boundary .node_boundary (self , self .nodes_to_show )
247+ new = nx .algorithms .boundary .node_boundary (result , result .nodes_to_show )
207248 if not new :
208249 break
209250 # add nodes referenced by aliased nodes
210- new .update (nx .algorithms .boundary .node_boundary (self , (a for a in new if a .isdigit ())))
211- self .nodes_to_show .update (new )
212- return self
251+ new .update (nx .algorithms .boundary .node_boundary (result , (a for a in new if a .isdigit ())))
252+ result .nodes_to_show .update (new )
253+ # Expanded nodes from + N expansion are explicit
254+ if not self ._is_collapsed :
255+ result ._explicit_nodes = result .nodes_to_show .copy ()
256+ return result
213257
214258 def __sub__ (self , arg ) -> "Diagram" :
215259 """
@@ -305,6 +349,131 @@ def _make_graph(self) -> nx.DiGraph:
305349 nx .relabel_nodes (graph , mapping , copy = False )
306350 return graph
307351
352+ def _apply_collapse (self , graph : nx .DiGraph ) -> tuple [nx .DiGraph , dict [str , str ]]:
353+ """
354+ Apply collapse logic to the graph.
355+
356+ Nodes in nodes_to_show but not in _explicit_nodes are collapsed into
357+ single schema nodes.
358+
359+ Parameters
360+ ----------
361+ graph : nx.DiGraph
362+ The graph from _make_graph().
363+
364+ Returns
365+ -------
366+ tuple[nx.DiGraph, dict[str, str]]
367+ Modified graph and mapping of collapsed schema labels to their table count.
368+ """
369+ if not self ._explicit_nodes or self ._explicit_nodes == self .nodes_to_show :
370+ # No collapse needed
371+ return graph , {}
372+
373+ # Map full_table_names to class_names
374+ full_to_class = {
375+ node : lookup_class_name (node , self .context ) or node
376+ for node in self .nodes_to_show
377+ }
378+ class_to_full = {v : k for k , v in full_to_class .items ()}
379+
380+ # Identify explicit class names (should be expanded)
381+ explicit_class_names = {
382+ full_to_class .get (node , node ) for node in self ._explicit_nodes
383+ }
384+
385+ # Identify nodes to collapse (class names)
386+ nodes_to_collapse = set (graph .nodes ()) - explicit_class_names
387+
388+ if not nodes_to_collapse :
389+ return graph , {}
390+
391+ # Group collapsed nodes by schema
392+ collapsed_by_schema = {} # schema_name -> list of class_names
393+ for class_name in nodes_to_collapse :
394+ full_name = class_to_full .get (class_name )
395+ if full_name :
396+ parts = full_name .replace ('"' , '`' ).split ('`' )
397+ if len (parts ) >= 2 :
398+ schema_name = parts [1 ]
399+ if schema_name not in collapsed_by_schema :
400+ collapsed_by_schema [schema_name ] = []
401+ collapsed_by_schema [schema_name ].append (class_name )
402+
403+ if not collapsed_by_schema :
404+ return graph , {}
405+
406+ # Determine labels for collapsed schemas
407+ schema_modules = {}
408+ for schema_name , class_names in collapsed_by_schema .items ():
409+ schema_modules [schema_name ] = set ()
410+ for class_name in class_names :
411+ cls = self ._resolve_class (class_name )
412+ if cls is not None and hasattr (cls , "__module__" ):
413+ module_name = cls .__module__ .split ("." )[- 1 ]
414+ schema_modules [schema_name ].add (module_name )
415+
416+ collapsed_labels = {} # schema_name -> label
417+ collapsed_counts = {} # label -> count of tables
418+ for schema_name , modules in schema_modules .items ():
419+ if len (modules ) == 1 :
420+ label = next (iter (modules ))
421+ else :
422+ label = schema_name
423+ collapsed_labels [schema_name ] = label
424+ collapsed_counts [label ] = len (collapsed_by_schema [schema_name ])
425+
426+ # Create new graph with collapsed nodes
427+ new_graph = nx .DiGraph ()
428+
429+ # Map old node names to new names (collapsed nodes -> schema label)
430+ node_mapping = {}
431+ for node in graph .nodes ():
432+ full_name = class_to_full .get (node )
433+ if full_name :
434+ parts = full_name .replace ('"' , '`' ).split ('`' )
435+ if len (parts ) >= 2 and node in nodes_to_collapse :
436+ schema_name = parts [1 ]
437+ node_mapping [node ] = collapsed_labels [schema_name ]
438+ else :
439+ node_mapping [node ] = node
440+ else :
441+ # Alias nodes - check if they should be collapsed
442+ # An alias node should be collapsed if ALL its neighbors are collapsed
443+ neighbors = set (graph .predecessors (node )) | set (graph .successors (node ))
444+ if neighbors and neighbors <= nodes_to_collapse :
445+ # Get schema from first neighbor
446+ neighbor = next (iter (neighbors ))
447+ full_name = class_to_full .get (neighbor )
448+ if full_name :
449+ parts = full_name .replace ('"' , '`' ).split ('`' )
450+ if len (parts ) >= 2 :
451+ schema_name = parts [1 ]
452+ node_mapping [node ] = collapsed_labels [schema_name ]
453+ continue
454+ node_mapping [node ] = node
455+
456+ # Add nodes
457+ added_collapsed = set ()
458+ for old_node , new_node in node_mapping .items ():
459+ if new_node in collapsed_counts :
460+ # This is a collapsed schema node
461+ if new_node not in added_collapsed :
462+ new_graph .add_node (new_node , node_type = None , collapsed = True ,
463+ table_count = collapsed_counts [new_node ])
464+ added_collapsed .add (new_node )
465+ else :
466+ new_graph .add_node (new_node , ** graph .nodes [old_node ])
467+
468+ # Add edges (avoiding self-loops and duplicates)
469+ for src , dest , data in graph .edges (data = True ):
470+ new_src = node_mapping [src ]
471+ new_dest = node_mapping [dest ]
472+ if new_src != new_dest and not new_graph .has_edge (new_src , new_dest ):
473+ new_graph .add_edge (new_src , new_dest , ** data )
474+
475+ return new_graph , collapsed_counts
476+
308477 def _resolve_class (self , name : str ):
309478 """
310479 Safely resolve a table class from a dotted name without eval().
@@ -379,6 +548,9 @@ def make_dot(self):
379548 direction = config .display .diagram_direction
380549 graph = self ._make_graph ()
381550
551+ # Apply collapse logic if needed
552+ graph , collapsed_counts = self ._apply_collapse (graph )
553+
382554 # Build schema mapping: class_name -> schema_name
383555 # Group by database schema, label with Python module name if 1:1 mapping
384556 schema_map = {} # class_name -> schema_name
@@ -474,8 +646,22 @@ def make_dot(self):
474646 size = 0.1 * scale ,
475647 fixed = False ,
476648 ),
649+ "collapsed" : dict (
650+ shape = "box3d" ,
651+ color = "#80808060" ,
652+ fontcolor = "#404040" ,
653+ fontsize = round (scale * 10 ),
654+ size = 0.5 * scale ,
655+ fixed = False ,
656+ ),
477657 }
478- node_props = {node : label_props [d ["node_type" ]] for node , d in dict (graph .nodes (data = True )).items ()}
658+ # Build node_props, handling collapsed nodes specially
659+ node_props = {}
660+ for node , d in graph .nodes (data = True ):
661+ if d .get ("collapsed" ):
662+ node_props [node ] = label_props ["collapsed" ]
663+ else :
664+ node_props [node ] = label_props [d ["node_type" ]]
479665
480666 self ._encapsulate_node_names (graph )
481667 self ._encapsulate_edge_attributes (graph )
@@ -492,23 +678,32 @@ def make_dot(self):
492678 node .set_fixedsize ("shape" if props ["fixed" ] else False )
493679 node .set_width (props ["size" ])
494680 node .set_height (props ["size" ])
495- cls = self ._resolve_class (name )
496- if cls is not None :
497- description = cls ().describe (context = self .context ).split ("\n " )
498- description = (
499- ("-" * 30 if q .startswith ("---" ) else (q .replace ("->" , "→" ) if "->" in q else q .split (":" )[0 ]))
500- for q in description
501- if not q .startswith ("#" )
502- )
503- node .set_tooltip (" " .join (description ))
504- # Strip module prefix from label if it matches the cluster label
505- display_name = name
506- schema_name = schema_map .get (name )
507- if schema_name and "." in name :
508- prefix = name .rsplit ("." , 1 )[0 ]
509- if prefix == cluster_labels .get (schema_name ):
510- display_name = name .rsplit ("." , 1 )[1 ]
511- node .set_label ("<<u>" + display_name + "</u>>" if node .get ("distinguished" ) == "True" else display_name )
681+
682+ # Handle collapsed nodes specially
683+ node_data = graph .nodes .get (f'"{ name } "' , {})
684+ if node_data .get ("collapsed" ):
685+ table_count = node_data .get ("table_count" , 0 )
686+ label = f"{ name } \\ n({ table_count } tables)" if table_count != 1 else f"{ name } \\ n(1 table)"
687+ node .set_label (label )
688+ node .set_tooltip (f"Collapsed schema: { table_count } tables" )
689+ else :
690+ cls = self ._resolve_class (name )
691+ if cls is not None :
692+ description = cls ().describe (context = self .context ).split ("\n " )
693+ description = (
694+ ("-" * 30 if q .startswith ("---" ) else (q .replace ("->" , "→" ) if "->" in q else q .split (":" )[0 ]))
695+ for q in description
696+ if not q .startswith ("#" )
697+ )
698+ node .set_tooltip (" " .join (description ))
699+ # Strip module prefix from label if it matches the cluster label
700+ display_name = name
701+ schema_name = schema_map .get (name )
702+ if schema_name and "." in name :
703+ prefix = name .rsplit ("." , 1 )[0 ]
704+ if prefix == cluster_labels .get (schema_name ):
705+ display_name = name .rsplit ("." , 1 )[1 ]
706+ node .set_label ("<<u>" + display_name + "</u>>" if node .get ("distinguished" ) == "True" else display_name )
512707 node .set_color (props ["color" ])
513708 node .set_style ("filled" )
514709
@@ -520,11 +715,12 @@ def make_dot(self):
520715 if props is None :
521716 raise DataJointError ("Could not find edge with source '{}' and destination '{}'" .format (src , dest ))
522717 edge .set_color ("#00000040" )
523- edge .set_style ("solid" if props ["primary" ] else "dashed" )
524- master_part = graph .nodes [dest ]["node_type" ] is Part and dest .startswith (src + "." )
718+ edge .set_style ("solid" if props .get ("primary" ) else "dashed" )
719+ dest_node_type = graph .nodes [dest ].get ("node_type" )
720+ master_part = dest_node_type is Part and dest .startswith (src + "." )
525721 edge .set_weight (3 if master_part else 1 )
526722 edge .set_arrowhead ("none" )
527- edge .set_penwidth (0.75 if props [ "multi" ] else 2 )
723+ edge .set_penwidth (0.75 if props . get ( "multi" ) else 2 )
528724
529725 # Group nodes into schema clusters (always on)
530726 if schema_map :
@@ -604,6 +800,9 @@ def make_mermaid(self) -> str:
604800 graph = self ._make_graph ()
605801 direction = config .display .diagram_direction
606802
803+ # Apply collapse logic if needed
804+ graph , collapsed_counts = self ._apply_collapse (graph )
805+
607806 # Build schema mapping for grouping
608807 schema_map = {} # class_name -> schema_name
609808 schema_modules = {} # schema_name -> set of module names
@@ -646,6 +845,7 @@ def make_mermaid(self) -> str:
646845 lines .append (" classDef computed fill:#FFB6C1,stroke:#8B0000" )
647846 lines .append (" classDef imported fill:#ADD8E6,stroke:#00008B" )
648847 lines .append (" classDef part fill:#FFFFFF,stroke:#000000" )
848+ lines .append (" classDef collapsed fill:#808080,stroke:#404040" )
649849 lines .append ("" )
650850
651851 # Shape mapping: Manual=box, Computed/Imported=stadium, Lookup/Part=box
@@ -669,14 +869,25 @@ def make_mermaid(self) -> str:
669869 None : "" ,
670870 }
671871
672- # Group nodes by schema into subgraphs
872+ # Group nodes by schema into subgraphs (only non-collapsed nodes)
673873 schemas = {}
874+ collapsed_nodes = []
674875 for node , data in graph .nodes (data = True ):
675- schema_name = schema_map .get (node )
676- if schema_name :
677- if schema_name not in schemas :
678- schemas [schema_name ] = []
679- schemas [schema_name ].append ((node , data ))
876+ if data .get ("collapsed" ):
877+ collapsed_nodes .append ((node , data ))
878+ else :
879+ schema_name = schema_map .get (node )
880+ if schema_name :
881+ if schema_name not in schemas :
882+ schemas [schema_name ] = []
883+ schemas [schema_name ].append ((node , data ))
884+
885+ # Add collapsed nodes (not in subgraphs)
886+ for node , data in collapsed_nodes :
887+ safe_id = node .replace ("." , "_" ).replace (" " , "_" )
888+ table_count = data .get ("table_count" , 0 )
889+ count_text = f"{ table_count } tables" if table_count != 1 else "1 table"
890+ lines .append (f" { safe_id } [[\" { node } <br/>({ count_text } )\" ]]:::collapsed" )
680891
681892 # Add nodes grouped by schema subgraphs
682893 for schema_name , nodes in schemas .items ():
0 commit comments