@@ -355,12 +355,15 @@ def _select_samples_to_label(self, data, trainer_cls, session):
355355 assign to each of the selected nodes.
356356 """
357357 # Select the candidate samples for self-labeling, and make predictions.
358- # Remove the validation samples from the unlabeled data, if there, to avoid
359- # self-labeling them.
358+ # We remove the validation and test samples from the unlabeled data,
359+ # to avoid self-labeling them. We could potentially allow them to be
360+ # self-labeled, but once a node is self-labeled its label is fixed for
361+ # the remaining co-train iterations, so it would not take advantage
362+ # of the improved versions of the model.
360363 indices_unlabeled = data .get_indices_unlabeled ()
361- val_ind = set (data .get_indices_val ())
364+ eval_ind = set (data .get_indices_val ()) | set ( data . get_indices_test ())
362365 indices_unlabeled = np .asarray (
363- [ind for ind in indices_unlabeled if ind not in val_ind ])
366+ [ind for ind in indices_unlabeled if ind not in eval_ind ])
364367 predictions = trainer_cls .predict (
365368 session , indices_unlabeled , is_train = False )
366369
@@ -546,8 +549,8 @@ def train(self, data, **kwargs):
546549
547550 # Create a saver which saves only the variables that we would need to
548551 # restore in case the training process is restarted.
549- vars_to_save = [iter_cotrain ] + trainer_agr . vars_to_save + \
550- trainer_cls .vars_to_save
552+ vars_to_save = [iter_cotrain
553+ ] + trainer_agr . vars_to_save + trainer_cls .vars_to_save
551554 saver = tf .train .Saver (vars_to_save )
552555
553556 # Create a TensorFlow session. We allow soft placement in order to place
@@ -633,7 +636,9 @@ def train(self, data, **kwargs):
633636 logging .info (
634637 '--------- Cotrain step %6d | Accuracy val: %10.4f | '
635638 'Accuracy test: %10.4f ---------' , step , val_acc , test_acc )
636-
639+ logging .info (
640+ 'Best validation acc: %.4f, corresponding test acc: %.4f at '
641+ 'iteration %d' , best_val_acc , test_acc_at_best , iter_at_best )
637642 if self .first_iter_original and step == 0 :
638643 logging .info ('No self-labeling because the first iteration trains the '
639644 'original classifier for evaluation purposes.' )
0 commit comments