Skip to content

Commit 5312f11

Browse files
Neural-Link Teamtensorflow-copybara
authored andcommitted
Fix a check for validity of input.
PiperOrigin-RevId: 366498374
1 parent bd21fef commit 5312f11

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

research/carls/candidate_sampling/log_uniform_sampler.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,12 @@ absl::Status LogUniformSampler::SampleUnique(
125125
absl::flat_hash_set<absl::string_view> pos_set(positive_keys.begin(),
126126
positive_keys.end());
127127
const size_t range = all_keys.size();
128+
if (num_sampled > range) {
129+
return absl::InvalidArgumentError(
130+
absl::StrCat("Not enough data in the KnolwedgeBank available for "
131+
"unique sampling. Total keys: ",
132+
range, ", num_sampled: ", num_sampled, "."));
133+
}
128134
results->clear();
129135
results->reserve(num_sampled);
130136

@@ -159,12 +165,6 @@ absl::Status LogUniformSampler::SampleUnique(
159165
return absl::OkStatus();
160166
}
161167

162-
if (num_sampled > range) {
163-
return absl::InternalError(
164-
"num_samples is larger than the total number of availabe candidates in "
165-
"the knowledge bank. Potentially caused by the positive keys are not "
166-
"saved to the knowledge bank.");
167-
}
168168
// Case Three: positive_keys.size() < num_sampled < range, sample randomly.
169169
const float prob = static_cast<float>(num_sampled - pos_set.size()) /
170170
static_cast<float>(range - pos_set.size());

0 commit comments

Comments
 (0)