Skip to content

Commit

Permalink
Fixed train/test_size calculation.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 337886488
  • Loading branch information
dvadym authored and tensorflower-gardener committed Oct 19, 2020
1 parent 19ae5c9 commit 4143957
Showing 1 changed file with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,12 @@ def num_classes(self):
@property
def logits_or_probs_train(self):
"""Returns train logits or probs whatever is not None."""
return self.logits_train if self.probs_train is None else self.probs_train
return self.probs_train or self.logits_train

@property
def logits_or_probs_test(self):
"""Returns test logits or probs whatever is not None."""
return self.logits_test if self.probs_test is None else self.probs_test
return self.probs_test or self.logits_test

@staticmethod
def _get_entropy(logits: np.ndarray, true_labels: np.ndarray):
Expand Down Expand Up @@ -278,13 +278,13 @@ def get_train_size(self):
"""Returns size of the training set."""
if self.loss_train is not None:
return self.loss_train.size
return self.logits_train.shape[0]
return self.logits_or_probs_train.shape[0]

def get_test_size(self):
"""Returns size of the test set."""
if self.loss_test is not None:
return self.loss_test.size
return self.logits_test.shape[0]
return self.logits_or_probs_test.shape[0]

def validate(self):
"""Validates the inputs."""
Expand Down

0 comments on commit 4143957

Please sign in to comment.