@@ -226,7 +226,12 @@ def __add__(self, arg) -> "Diagram":
226226 """
227227 result = Diagram (self ) # copy
228228 try :
229+ # Merge nodes and edges from the other diagram
230+ result .add_nodes_from (arg .nodes (data = True ))
231+ result .add_edges_from (arg .edges (data = True ))
229232 result .nodes_to_show .update (arg .nodes_to_show )
233+ # Merge contexts for class name lookups
234+ result .context = {** result .context , ** arg .context }
230235 # Handle collapse: nodes from non-collapsed diagrams are explicit (expanded)
231236 if not self ._is_collapsed :
232237 result ._explicit_nodes .update (self .nodes_to_show )
@@ -326,18 +331,20 @@ def _make_graph(self) -> nx.DiGraph:
326331 """
327332 # mark "distinguished" tables, i.e. those that introduce new primary key
328333 # attributes
329- for name in self .nodes_to_show :
334+ # Filter nodes_to_show to only include nodes that exist in the graph
335+ valid_nodes = self .nodes_to_show .intersection (set (self .nodes ()))
336+ for name in valid_nodes :
330337 foreign_attributes = set (
331338 attr for p in self .in_edges (name , data = True ) for attr in p [2 ]["attr_map" ] if p [2 ]["primary" ]
332339 )
333340 self .nodes [name ]["distinguished" ] = (
334341 "primary_key" in self .nodes [name ] and foreign_attributes < self .nodes [name ]["primary_key" ]
335342 )
336343 # include aliased nodes that are sandwiched between two displayed nodes
337- gaps = set (nx .algorithms .boundary .node_boundary (self , self . nodes_to_show )).intersection (
338- nx .algorithms .boundary .node_boundary (nx .DiGraph (self ).reverse (), self . nodes_to_show )
344+ gaps = set (nx .algorithms .boundary .node_boundary (self , valid_nodes )).intersection (
345+ nx .algorithms .boundary .node_boundary (nx .DiGraph (self ).reverse (), valid_nodes )
339346 )
340- nodes = self . nodes_to_show .union (a for a in gaps if a .isdigit ())
347+ nodes = valid_nodes .union (a for a in gaps if a .isdigit ())
341348 # construct subgraph and rename nodes to class names
342349 graph = nx .DiGraph (nx .DiGraph (self ).subgraph (nodes ))
343350 nx .set_node_attributes (graph , name = "node_type" , values = {n : _get_tier (n ) for n in graph })
@@ -366,20 +373,24 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str]
366373 tuple[nx.DiGraph, dict[str, str]]
367374 Modified graph and mapping of collapsed schema labels to their table count.
368375 """
369- if not self ._explicit_nodes or self ._explicit_nodes == self .nodes_to_show :
376+ # Filter to valid nodes (those that exist in the underlying graph)
377+ valid_nodes = self .nodes_to_show .intersection (set (self .nodes ()))
378+ valid_explicit = self ._explicit_nodes .intersection (set (self .nodes ()))
379+
380+ if not valid_explicit or valid_explicit == valid_nodes :
370381 # No collapse needed
371382 return graph , {}
372383
373384 # Map full_table_names to class_names
374385 full_to_class = {
375386 node : lookup_class_name (node , self .context ) or node
376- for node in self . nodes_to_show
387+ for node in valid_nodes
377388 }
378389 class_to_full = {v : k for k , v in full_to_class .items ()}
379390
380391 # Identify explicit class names (should be expanded)
381392 explicit_class_names = {
382- full_to_class .get (node , node ) for node in self . _explicit_nodes
393+ full_to_class .get (node , node ) for node in valid_explicit
383394 }
384395
385396 # Identify nodes to collapse (class names)
0 commit comments