Skip to content

Commit f17eb3a

Browse files
Merge pull request #82 from otiliastr:fix_edges
PiperOrigin-RevId: 368331006
2 parents 01bb1fb + 52a64ca commit f17eb3a

File tree

3 files changed

+19
-9
lines changed

3 files changed

+19
-9
lines changed

research/gam/gam/data/dataset.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,12 +253,19 @@ class GraphDataset(Dataset):
253253
"""Data container for SSL datasets."""
254254

255255
class Edge(object):
256+
"""Graph Edge."""
256257

257258
def __init__(self, src, tgt, weight=None):
258259
self.src = src
259260
self.tgt = tgt
260261
self.weight = weight
261262

263+
def copy(self, src=None, tgt=None, weight=None):
264+
src = src if src is not None else self.src
265+
tgt = tgt if tgt is not None else self.tgt
266+
weight = weight if weight is not None else self.weight
267+
return GraphDataset.Edge(src, tgt, weight)
268+
262269
def __init__(self,
263270
name,
264271
features,
@@ -879,9 +886,8 @@ def restore_state_from_file(self, path):
879886
indices_type = (
880887
np.uint32
881888
if self.dataset.num_samples < np.iinfo(np.uint32).max else np.uint64)
882-
if os.path.exists(file_indices_train) \
883-
and os.path.exists(file_indices_unlabeled) \
884-
and os.path.exists(file_labels_train):
889+
if os.path.exists(file_indices_train) and os.path.exists(
890+
file_indices_unlabeled) and os.path.exists(file_labels_train):
885891
with open(file_indices_train, 'r') as f:
886892
indices_train = np.genfromtxt(f, delimiter=',', dtype=indices_type)
887893
with open(file_indices_unlabeled, 'r') as f:

research/gam/gam/trainer/trainer_classification.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -661,9 +661,11 @@ def edge_iterator(self, data, batch_size, labeling):
661661
edges = data.get_edges(
662662
src_labeled=True, tgt_labeled=True, label_must_match=True)
663663
elif labeling == 'lu':
664-
edges = (
665-
data.get_edges(src_labeled=True, tgt_labeled=False) +
666-
data.get_edges(src_labeled=False, tgt_labeled=True))
664+
edges_lu = data.get_edges(src_labeled=True, tgt_labeled=False)
665+
edges_ul = data.get_edges(src_labeled=False, tgt_labeled=True)
666+
# Reverse the edges of UL to be LU.
667+
edges_ul = [e.copy(src=e.tgt, tgt=e.src) for e in edges_ul]
668+
edges = edges_lu + edges_ul
667669
elif labeling == 'uu':
668670
edges = data.get_edges(src_labeled=False, tgt_labeled=False)
669671
else:

research/gam/gam/trainer/trainer_classification_gcn.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -689,9 +689,11 @@ def edge_iterator(self, data, batch_size, labeling):
689689
edges = data.get_edges(
690690
src_labeled=True, tgt_labeled=True, label_must_match=True)
691691
elif labeling == 'lu':
692-
edges = (
693-
data.get_edges(src_labeled=True, tgt_labeled=False) +
694-
data.get_edges(src_labeled=False, tgt_labeled=True))
692+
edges_lu = data.get_edges(src_labeled=True, tgt_labeled=False)
693+
edges_ul = data.get_edges(src_labeled=False, tgt_labeled=True)
694+
# Reverse the edges of UL to be LU.
695+
edges_ul = [e.copy(src=e.tgt, tgt=e.src) for e in edges_ul]
696+
edges = edges_lu + edges_ul
695697
elif labeling == 'uu':
696698
edges = data.get_edges(src_labeled=False, tgt_labeled=False)
697699
else:

0 commit comments

Comments
 (0)