Skip to content

Commit 982e4f4

Browse files
committed
Refactor if statement.
1 parent 6caa15b commit 982e4f4

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

research/gam/gam/trainer/trainer_classification_gcn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -871,7 +871,7 @@ def train(self, data, session=None, **kwargs):
871871

872872
def predict(self, session, indices, is_train):
873873
"""Make predictions for the provided sample indices."""
874-
if not indices:
874+
if not indices.shape[0]:
875875
return np.zeros((0, self.data.num_classes), dtype=np.float32)
876876
feed_dict = {
877877
self.input_indices: indices,

0 commit comments

Comments
 (0)