Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions drivers/python/age/networkx/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def getEdgeLabelListAfterPreprocessing(G: nx.DiGraph):


def addAllNodesIntoAGE(connection: psycopg.connect, graphName: str, G: nx.DiGraph, node_label_list: Set):
"""Add all node to AGE"""
"""Add all node to AGE using the unified vertex table"""
try:
queue_data = {label: [] for label in node_label_list}
id_data = {}
Expand All @@ -179,9 +179,11 @@ def addAllNodesIntoAGE(connection: psycopg.connect, graphName: str, G: nx.DiGrap
json_string = json.dumps(data['properties'])
queue_data[data['label']].append((json_string,))

unified_table = f"{graphName}._ag_label_vertex"
for label, rows in queue_data.items():
table_name = """%s."%s" """ % (graphName, label)
insert_query = f"INSERT INTO {table_name} (properties) VALUES (%s) RETURNING id"
# Get the label table OID for the labels column
label_table = f'{graphName}."{label}"'
insert_query = f"INSERT INTO {unified_table} (properties, labels) VALUES (%s, '{label_table}'::regclass::oid) RETURNING id"
cursor = connection.cursor()
cursor.executemany(insert_query, rows, returning=True)
ids = []
Expand Down Expand Up @@ -224,19 +226,21 @@ def addAllEdgesIntoAGE(connection: psycopg.connect, graphName: str, G: nx.DiGrap


def addAllNodesIntoNetworkx(connection: psycopg.connect, graphName: str, G: nx.DiGraph):
"""Add all nodes to Networkx"""
node_label_list = get_vlabel(connection, graphName)
"""Add all nodes to Networkx from the unified vertex table"""
try:
for label in node_label_list:
with connection.cursor() as cursor:
cursor.execute("""
SELECT id, CAST(properties AS VARCHAR)
FROM %s."%s";
""" % (graphName, label))
rows = cursor.fetchall()
for row in rows:
G.add_node(int(row[0]), label=label,
properties=json.loads(row[1]))
with connection.cursor() as cursor:
# Read all vertices from unified table, getting label from labels column
cursor.execute("""
SELECT id, CAST(properties AS VARCHAR),
ag_catalog._label_name_from_table_oid(labels) as label
FROM %s._ag_label_vertex;
""" % graphName)
rows = cursor.fetchall()
for row in rows:
# Empty string label means default/unlabeled vertex
label = row[2] if row[2] else '_ag_label_vertex'
G.add_node(int(row[0]), label=label,
properties=json.loads(row[1]))
except Exception as e:
print(e)

Expand Down