Skip to content

Commit 6caa15b

Browse files
committed
Avoiding self-labeling the test nodes, to allow them to take advantage of the model trained in the latter co-training steps.
1 parent 4a574b8 commit 6caa15b

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

research/gam/gam/trainer/trainer_cotrain.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -355,12 +355,12 @@ def _select_samples_to_label(self, data, trainer_cls, session):
355355
assign to each of the selected nodes.
356356
"""
357357
# Select the candidate samples for self-labeling, and make predictions.
358-
# Remove the validation samples from the unlabeled data, if there, to avoid
359-
# self-labeling them.
358+
# Remove the validation and test samples from the unlabeled data, if there,
359+
# to avoid self-labeling them.
360360
indices_unlabeled = data.get_indices_unlabeled()
361-
val_ind = set(data.get_indices_val())
361+
eval_ind = set(data.get_indices_val()) | set(data.get_indices_test())
362362
indices_unlabeled = np.asarray(
363-
[ind for ind in indices_unlabeled if ind not in val_ind])
363+
[ind for ind in indices_unlabeled if ind not in eval_ind])
364364
predictions = trainer_cls.predict(
365365
session, indices_unlabeled, is_train=False)
366366

0 commit comments

Comments
 (0)