Skip to content

Commit 52ea737

Browse files
committed
Fix to edge iterator.
1 parent eeeb95b commit 52ea737

File tree

3 files changed

+16
-6
lines changed

3 files changed

+16
-6
lines changed

research/gam/gam/data/dataset.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,12 @@ def __init__(self, src, tgt, weight=None):
259259
self.tgt = tgt
260260
self.weight = weight
261261

262+
def copy(self, src=None, tgt=None, weight=None):
263+
src = src if src is not None else self.src
264+
tgt = tgt if tgt is not None else self.tgt
265+
weight = weight if weight is not None else self.weight
266+
return GraphDataset.Edge(src, tgt, weight)
267+
262268
def __init__(self,
263269
name,
264270
features,

research/gam/gam/trainer/trainer_classification.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -661,9 +661,11 @@ def edge_iterator(self, data, batch_size, labeling):
661661
edges = data.get_edges(
662662
src_labeled=True, tgt_labeled=True, label_must_match=True)
663663
elif labeling == 'lu':
664-
edges = (
665-
data.get_edges(src_labeled=True, tgt_labeled=False) +
666-
data.get_edges(src_labeled=False, tgt_labeled=True))
664+
edges_lu = data.get_edges(src_labeled=True, tgt_labeled=False)
665+
edges_ul = data.get_edges(src_labeled=False, tgt_labeled=True)
666+
# Reverse the edges of UL to be LU.
667+
edges_ul = [e.copy(src=e.tgt, tgt=e.src) for e in edges_ul]
668+
edges = edges_lu + edges_ul
667669
elif labeling == 'uu':
668670
edges = data.get_edges(src_labeled=False, tgt_labeled=False)
669671
else:

research/gam/gam/trainer/trainer_classification_gcn.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -689,9 +689,11 @@ def edge_iterator(self, data, batch_size, labeling):
689689
edges = data.get_edges(
690690
src_labeled=True, tgt_labeled=True, label_must_match=True)
691691
elif labeling == 'lu':
692-
edges = (
693-
data.get_edges(src_labeled=True, tgt_labeled=False) +
694-
data.get_edges(src_labeled=False, tgt_labeled=True))
692+
edges_lu = data.get_edges(src_labeled=True, tgt_labeled=False)
693+
edges_ul = data.get_edges(src_labeled=False, tgt_labeled=True)
694+
# Reverse the edges of UL to be LU.
695+
edges_ul = [e.copy(src=e.tgt, tgt=e.src) for e in edges_ul]
696+
edges = edges_lu + edges_ul
695697
elif labeling == 'uu':
696698
edges = data.get_edges(src_labeled=False, tgt_labeled=False)
697699
else:

0 commit comments

Comments
 (0)