Skip to content

Commit d41b75f

Browse files
feat: improve schema grouping labels with fallback logic
- Collect all module names per schema, not just the first - Use Python module name as label if 1:1 mapping with schema - Fall back to database schema name if multiple modules - Strip module prefix from class names when it matches cluster label Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 903e6b2 commit d41b75f

File tree

1 file changed

+52
-19
lines changed

1 file changed

+52
-19
lines changed

src/datajoint/diagram.py

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -379,10 +379,10 @@ def make_dot(self):
379379
direction = config.display.diagram_direction
380380
graph = self._make_graph()
381381

382-
# Build schema mapping: class_name -> (schema_name, module_name)
383-
# Group by database schema, but label with Python module name when available
382+
# Build schema mapping: class_name -> schema_name
383+
# Group by database schema, label with Python module name if 1:1 mapping
384384
schema_map = {} # class_name -> schema_name
385-
module_map = {} # schema_name -> module_name (for cluster labels)
385+
schema_modules = {} # schema_name -> set of module names
386386

387387
for full_name in self.nodes_to_show:
388388
# Extract schema from full table name like `schema`.`table` or "schema"."table"
@@ -392,12 +392,21 @@ def make_dot(self):
392392
class_name = lookup_class_name(full_name, self.context) or full_name
393393
schema_map[class_name] = schema_name
394394

395-
# Try to get Python module name for the cluster label
396-
if schema_name not in module_map:
397-
cls = self._resolve_class(class_name)
398-
if cls is not None and hasattr(cls, "__module__"):
399-
# Use the last part of the module path (e.g., "my_pipeline" from "package.my_pipeline")
400-
module_map[schema_name] = cls.__module__.split(".")[-1]
395+
# Collect all module names for this schema
396+
if schema_name not in schema_modules:
397+
schema_modules[schema_name] = set()
398+
cls = self._resolve_class(class_name)
399+
if cls is not None and hasattr(cls, "__module__"):
400+
module_name = cls.__module__.split(".")[-1]
401+
schema_modules[schema_name].add(module_name)
402+
403+
# Determine cluster labels: use module name if 1:1, else database schema name
404+
cluster_labels = {} # schema_name -> label
405+
for schema_name, modules in schema_modules.items():
406+
if len(modules) == 1:
407+
cluster_labels[schema_name] = next(iter(modules))
408+
else:
409+
cluster_labels[schema_name] = schema_name
401410

402411
# Assign alias nodes (orange dots) to the same schema as their child table
403412
for node, data in graph.nodes(data=True):
@@ -492,7 +501,14 @@ def make_dot(self):
492501
if not q.startswith("#")
493502
)
494503
node.set_tooltip("&#13;".join(description))
495-
node.set_label("<<u>" + name + "</u>>" if node.get("distinguished") == "True" else name)
504+
# Strip module prefix from label if it matches the cluster label
505+
display_name = name
506+
schema_name = schema_map.get(name)
507+
if schema_name and "." in name:
508+
prefix = name.rsplit(".", 1)[0]
509+
if prefix == cluster_labels.get(schema_name):
510+
display_name = name.rsplit(".", 1)[1]
511+
node.set_label("<<u>" + display_name + "</u>>" if node.get("distinguished") == "True" else display_name)
496512
node.set_color(props["color"])
497513
node.set_style("filled")
498514

@@ -525,9 +541,9 @@ def make_dot(self):
525541
schemas[schema_name].append(node)
526542

527543
# Create clusters for each schema
528-
# Use Python module name as label when available, otherwise database schema name
544+
# Use Python module name if 1:1 mapping, otherwise database schema name
529545
for schema_name, nodes in schemas.items():
530-
label = module_map.get(schema_name, schema_name)
546+
label = cluster_labels.get(schema_name, schema_name)
531547
cluster = pydot.Cluster(
532548
f"cluster_{schema_name}",
533549
label=label,
@@ -590,7 +606,7 @@ def make_mermaid(self) -> str:
590606

591607
# Build schema mapping for grouping
592608
schema_map = {} # class_name -> schema_name
593-
module_map = {} # schema_name -> module_name (for subgraph labels)
609+
schema_modules = {} # schema_name -> set of module names
594610

595611
for full_name in self.nodes_to_show:
596612
parts = full_name.replace('"', '`').split('`')
@@ -599,10 +615,21 @@ def make_mermaid(self) -> str:
599615
class_name = lookup_class_name(full_name, self.context) or full_name
600616
schema_map[class_name] = schema_name
601617

602-
if schema_name not in module_map:
603-
cls = self._resolve_class(class_name)
604-
if cls is not None and hasattr(cls, "__module__"):
605-
module_map[schema_name] = cls.__module__.split(".")[-1]
618+
# Collect all module names for this schema
619+
if schema_name not in schema_modules:
620+
schema_modules[schema_name] = set()
621+
cls = self._resolve_class(class_name)
622+
if cls is not None and hasattr(cls, "__module__"):
623+
module_name = cls.__module__.split(".")[-1]
624+
schema_modules[schema_name].add(module_name)
625+
626+
# Determine cluster labels: use module name if 1:1, else database schema name
627+
cluster_labels = {}
628+
for schema_name, modules in schema_modules.items():
629+
if len(modules) == 1:
630+
cluster_labels[schema_name] = next(iter(modules))
631+
else:
632+
cluster_labels[schema_name] = schema_name
606633

607634
# Assign alias nodes to the same schema as their child table
608635
for node, data in graph.nodes(data=True):
@@ -653,15 +680,21 @@ def make_mermaid(self) -> str:
653680

654681
# Add nodes grouped by schema subgraphs
655682
for schema_name, nodes in schemas.items():
656-
label = module_map.get(schema_name, schema_name)
683+
label = cluster_labels.get(schema_name, schema_name)
657684
lines.append(f" subgraph {label}")
658685
for node, data in nodes:
659686
tier = data.get("node_type")
660687
left, right = shape_map.get(tier, ("[", "]"))
661688
cls = tier_class.get(tier, "")
662689
safe_id = node.replace(".", "_").replace(" ", "_")
690+
# Strip module prefix from display name if it matches the cluster label
691+
display_name = node
692+
if "." in node:
693+
prefix = node.rsplit(".", 1)[0]
694+
if prefix == label:
695+
display_name = node.rsplit(".", 1)[1]
663696
class_suffix = f":::{cls}" if cls else ""
664-
lines.append(f" {safe_id}{left}{node}{right}{class_suffix}")
697+
lines.append(f" {safe_id}{left}{display_name}{right}{class_suffix}")
665698
lines.append(" end")
666699

667700
lines.append("")

0 commit comments

Comments
 (0)