@@ -379,10 +379,10 @@ def make_dot(self):
379379 direction = config .display .diagram_direction
380380 graph = self ._make_graph ()
381381
382- # Build schema mapping: class_name -> ( schema_name, module_name)
383- # Group by database schema, but label with Python module name when available
382+ # Build schema mapping: class_name -> schema_name
383+ # Group by database schema, label with Python module name if 1:1 mapping
384384 schema_map = {} # class_name -> schema_name
385- module_map = {} # schema_name -> module_name (for cluster labels)
385+ schema_modules = {} # schema_name -> set of module names
386386
387387 for full_name in self .nodes_to_show :
388388 # Extract schema from full table name like `schema`.`table` or "schema"."table"
@@ -392,12 +392,21 @@ def make_dot(self):
392392 class_name = lookup_class_name (full_name , self .context ) or full_name
393393 schema_map [class_name ] = schema_name
394394
395- # Try to get Python module name for the cluster label
396- if schema_name not in module_map :
397- cls = self ._resolve_class (class_name )
398- if cls is not None and hasattr (cls , "__module__" ):
399- # Use the last part of the module path (e.g., "my_pipeline" from "package.my_pipeline")
400- module_map [schema_name ] = cls .__module__ .split ("." )[- 1 ]
395+ # Collect all module names for this schema
396+ if schema_name not in schema_modules :
397+ schema_modules [schema_name ] = set ()
398+ cls = self ._resolve_class (class_name )
399+ if cls is not None and hasattr (cls , "__module__" ):
400+ module_name = cls .__module__ .split ("." )[- 1 ]
401+ schema_modules [schema_name ].add (module_name )
402+
403+ # Determine cluster labels: use module name if 1:1, else database schema name
404+ cluster_labels = {} # schema_name -> label
405+ for schema_name , modules in schema_modules .items ():
406+ if len (modules ) == 1 :
407+ cluster_labels [schema_name ] = next (iter (modules ))
408+ else :
409+ cluster_labels [schema_name ] = schema_name
401410
402411 # Assign alias nodes (orange dots) to the same schema as their child table
403412 for node , data in graph .nodes (data = True ):
@@ -492,7 +501,14 @@ def make_dot(self):
492501 if not q .startswith ("#" )
493502 )
494503 node .set_tooltip (" " .join (description ))
495- node .set_label ("<<u>" + name + "</u>>" if node .get ("distinguished" ) == "True" else name )
504+ # Strip module prefix from label if it matches the cluster label
505+ display_name = name
506+ schema_name = schema_map .get (name )
507+ if schema_name and "." in name :
508+ prefix = name .rsplit ("." , 1 )[0 ]
509+ if prefix == cluster_labels .get (schema_name ):
510+ display_name = name .rsplit ("." , 1 )[1 ]
511+ node .set_label ("<<u>" + display_name + "</u>>" if node .get ("distinguished" ) == "True" else display_name )
496512 node .set_color (props ["color" ])
497513 node .set_style ("filled" )
498514
@@ -525,9 +541,9 @@ def make_dot(self):
525541 schemas [schema_name ].append (node )
526542
527543 # Create clusters for each schema
528- # Use Python module name as label when available , otherwise database schema name
544+ # Use Python module name if 1:1 mapping , otherwise database schema name
529545 for schema_name , nodes in schemas .items ():
530- label = module_map .get (schema_name , schema_name )
546+ label = cluster_labels .get (schema_name , schema_name )
531547 cluster = pydot .Cluster (
532548 f"cluster_{ schema_name } " ,
533549 label = label ,
@@ -590,7 +606,7 @@ def make_mermaid(self) -> str:
590606
591607 # Build schema mapping for grouping
592608 schema_map = {} # class_name -> schema_name
593- module_map = {} # schema_name -> module_name (for subgraph labels)
609+ schema_modules = {} # schema_name -> set of module names
594610
595611 for full_name in self .nodes_to_show :
596612 parts = full_name .replace ('"' , '`' ).split ('`' )
@@ -599,10 +615,21 @@ def make_mermaid(self) -> str:
599615 class_name = lookup_class_name (full_name , self .context ) or full_name
600616 schema_map [class_name ] = schema_name
601617
602- if schema_name not in module_map :
603- cls = self ._resolve_class (class_name )
604- if cls is not None and hasattr (cls , "__module__" ):
605- module_map [schema_name ] = cls .__module__ .split ("." )[- 1 ]
618+ # Collect all module names for this schema
619+ if schema_name not in schema_modules :
620+ schema_modules [schema_name ] = set ()
621+ cls = self ._resolve_class (class_name )
622+ if cls is not None and hasattr (cls , "__module__" ):
623+ module_name = cls .__module__ .split ("." )[- 1 ]
624+ schema_modules [schema_name ].add (module_name )
625+
626+ # Determine cluster labels: use module name if 1:1, else database schema name
627+ cluster_labels = {}
628+ for schema_name , modules in schema_modules .items ():
629+ if len (modules ) == 1 :
630+ cluster_labels [schema_name ] = next (iter (modules ))
631+ else :
632+ cluster_labels [schema_name ] = schema_name
606633
607634 # Assign alias nodes to the same schema as their child table
608635 for node , data in graph .nodes (data = True ):
@@ -653,15 +680,21 @@ def make_mermaid(self) -> str:
653680
654681 # Add nodes grouped by schema subgraphs
655682 for schema_name , nodes in schemas .items ():
656- label = module_map .get (schema_name , schema_name )
683+ label = cluster_labels .get (schema_name , schema_name )
657684 lines .append (f" subgraph { label } " )
658685 for node , data in nodes :
659686 tier = data .get ("node_type" )
660687 left , right = shape_map .get (tier , ("[" , "]" ))
661688 cls = tier_class .get (tier , "" )
662689 safe_id = node .replace ("." , "_" ).replace (" " , "_" )
690+ # Strip module prefix from display name if it matches the cluster label
691+ display_name = node
692+ if "." in node :
693+ prefix = node .rsplit ("." , 1 )[0 ]
694+ if prefix == label :
695+ display_name = node .rsplit ("." , 1 )[1 ]
663696 class_suffix = f":::{ cls } " if cls else ""
664- lines .append (f" { safe_id } { left } { node } { right } { class_suffix } " )
697+ lines .append (f" { safe_id } { left } { display_name } { right } { class_suffix } " )
665698 lines .append (" end" )
666699
667700 lines .append ("" )
0 commit comments