Skip to content

Commit bd21fef

Browse files
Neural-Link Teamtensorflow-copybara
authored andcommitted
Add a candidate sampling component.
PiperOrigin-RevId: 366483052
1 parent 16ab2ff commit bd21fef

File tree

12 files changed

+1196
-1
lines changed

12 files changed

+1196
-1
lines changed
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Description:
16+
# Build rules for the candidate sampling component of CARLS.
17+
18+
load("//research/carls:bazel/build_rules.bzl", "carls_cc_proto_library", "carls_py_proto_library")
19+
20+
package(
21+
default_visibility = [
22+
":internal",
23+
],
24+
licenses = ["notice"], # Apache 2.0
25+
)
26+
27+
package_group(
28+
name = "internal",
29+
packages = [
30+
"//research/...",
31+
],
32+
)
33+
34+
carls_cc_proto_library(
35+
name = "candidate_sampler_config_cc_proto",
36+
srcs = ["candidate_sampler_config.proto"],
37+
deps = [
38+
"//research/carls:embedding_cc_proto",
39+
],
40+
)
41+
42+
carls_py_proto_library(
43+
name = "candidate_sampler_config_py_pb2",
44+
srcs = ["candidate_sampler_config.proto"],
45+
deps = [
46+
":candidate_sampler_config_cc_proto",
47+
"//research/carls:embedding_py_pb2",
48+
],
49+
)
50+
51+
cc_library(
52+
name = "candidate_sampler",
53+
srcs = ["candidate_sampler.cc"],
54+
hdrs = ["candidate_sampler.h"],
55+
deps = [
56+
":candidate_sampler_config_cc_proto",
57+
"//research/carls/base:proto_factory",
58+
"//research/carls/knowledge_bank",
59+
"@com_google_absl//absl/status",
60+
"@com_google_absl//absl/strings",
61+
],
62+
)
63+
64+
cc_test(
65+
name = "candidate_sampler_test",
66+
srcs = ["candidate_sampler_test.cc"],
67+
deps = [
68+
":candidate_sampler",
69+
"@com_google_absl//absl/status",
70+
"@com_google_googletest//:gtest_main",
71+
],
72+
)
73+
74+
cc_library(
75+
name = "log_uniform_sampler",
76+
srcs = ["log_uniform_sampler.cc"],
77+
deps = [
78+
":candidate_sampler",
79+
"//research/carls/base:proto_helper",
80+
"@com_google_absl//absl/container:flat_hash_set",
81+
"@com_google_absl//absl/status",
82+
"@tensorflow_includes//:includes",
83+
"@tensorflow_solib//:framework_lib",
84+
],
85+
alwayslink = 1,
86+
)
87+
88+
cc_test(
89+
name = "log_uniform_sampler_test",
90+
srcs = ["log_uniform_sampler_test.cc"],
91+
deps = [
92+
":candidate_sampler",
93+
":log_uniform_sampler",
94+
"//research/carls/knowledge_bank:initializer_cc_proto",
95+
"//research/carls/testing:test_helper",
96+
"@com_google_googletest//:gtest_main",
97+
],
98+
)
99+
100+
cc_library(
101+
name = "brute_force_topk_sampler",
102+
srcs = ["brute_force_topk_sampler.cc"],
103+
deps = [
104+
":candidate_sampler",
105+
":candidate_sampler_config_cc_proto",
106+
"//research/carls:embedding_cc_proto",
107+
"//research/carls/base:embedding_helper",
108+
"//research/carls/base:top_n",
109+
"@com_google_absl//absl/status",
110+
],
111+
alwayslink = 1,
112+
)
113+
114+
cc_test(
115+
name = "brute_force_topk_sampler_test",
116+
srcs = ["brute_force_topk_sampler_test.cc"],
117+
deps = [
118+
":brute_force_topk_sampler",
119+
":candidate_sampler",
120+
"//research/carls:embedding_cc_proto",
121+
"//research/carls/knowledge_bank:initializer_cc_proto",
122+
"//research/carls/knowledge_bank:initializer_helper",
123+
"//research/carls/testing:test_helper",
124+
"@com_google_googletest//:gtest_main",
125+
],
126+
)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Candidate Sampling
2+
3+
The candidate sampling component of CARLS is responsible for sampling data from
4+
a knowledge bank based on different application. For example
5+
6+
* Sampled Softmax/Logistic for efficient loss computation
7+
([TensorFlow Doc](https://www.tensorflow.org/extras/candidate_sampling.pdf)):
8+
this is useful when the target class in a model is very large or highly
9+
dynamic.
10+
11+
* Top-K/Nearest Neighbors: find the top-k closest target classses in a
12+
softmax/logistic top layer during model inference.
13+
14+
* Attention-based on external knowledge retrieval: retrieve an attention
15+
vector based on given input query from a sampled large set of values.
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
/*Copyright 2020 Google LLC
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <string>
17+
#include <vector>
18+
19+
#include "absl/status/status.h"
20+
#include "research/carls/base/embedding_helper.h"
21+
#include "research/carls/base/top_n.h"
22+
#include "research/carls/candidate_sampling/candidate_sampler.h"
23+
#include "research/carls/candidate_sampling/candidate_sampler_config.pb.h" // proto to pb
24+
#include "research/carls/embedding.pb.h" // proto to pb
25+
26+
namespace carls {
27+
namespace candidate_sampling {
28+
namespace {
29+
30+
// Represents an embedding in the knowledge bank, used for top-k comparison.
31+
struct CandidateInfo {
32+
// The key in the knowledge bank.
33+
absl::string_view key;
34+
35+
// The similarity between the activation and the embedding of the `key`.
36+
float similarity = 0;
37+
38+
// The embedding of the key. Make a copy to avoid accidental deallocation.
39+
EmbeddingVectorProto embed;
40+
41+
CandidateInfo(absl::string_view k, float s) : key(k), similarity(s) {}
42+
};
43+
44+
// Used for the top-k computation.
45+
struct CandidateInfoComparator {
46+
bool operator()(const CandidateInfo& lhs, const CandidateInfo& rhs) const {
47+
return lhs.similarity > rhs.similarity;
48+
}
49+
};
50+
51+
BruteForceTopkSamplerConfig GetTopkConfig(
52+
const CandidateSamplerConfig& config) {
53+
return GetExtensionProtoOrDie<CandidateSamplerConfig,
54+
BruteForceTopkSamplerConfig>(config);
55+
}
56+
57+
} // namespace
58+
59+
// A brute-force implementation of the top-k sampler. Each time the Sample()
60+
// method is called, it traverses all the embeddings in a knowledge bank and
61+
// compares their similarities with given input activation.
62+
class BruteForceTopkSampler : public CandidateSampler {
63+
public:
64+
BruteForceTopkSampler(const CandidateSamplerConfig& config)
65+
: CandidateSampler(config), topk_config_(GetTopkConfig(config)) {}
66+
67+
private:
68+
absl::Status SampleInternal(
69+
const KnowledgeBank& knowledge_bank, const SampleContext& sample_context,
70+
int num_samples, std::vector<SampledResult>* results) const override;
71+
72+
const BruteForceTopkSamplerConfig topk_config_;
73+
};
74+
75+
REGISTER_SAMPLER_FACTORY(BruteForceTopkSamplerConfig,
76+
[](const CandidateSamplerConfig& config)
77+
-> std::unique_ptr<CandidateSampler> {
78+
auto topk_config = GetTopkConfig(config);
79+
if (topk_config.similarity_type() == UNKNOWN) {
80+
LOG(ERROR)
81+
<< "Unknown similarity type, cannot create "
82+
"BruteForceTopkSampler.";
83+
return nullptr;
84+
}
85+
return std::unique_ptr<CandidateSampler>(
86+
new BruteForceTopkSampler(config));
87+
});
88+
89+
absl::Status BruteForceTopkSampler::SampleInternal(
90+
const KnowledgeBank& knowledge_bank, const SampleContext& sample_context,
91+
int num_samples, std::vector<SampledResult>* results) const {
92+
if (!sample_context.has_activation()) {
93+
return absl::InvalidArgumentError("No activation from sample_context.");
94+
}
95+
if (knowledge_bank.embedding_dimension() !=
96+
sample_context.activation().value_size()) {
97+
return absl::InvalidArgumentError(
98+
absl::StrCat("Invalid embedding dimension from activation, expect ",
99+
knowledge_bank.embedding_dimension(), ", got ",
100+
sample_context.activation().value_size(), "."));
101+
}
102+
103+
std::vector<absl::string_view> all_keys = knowledge_bank.Keys();
104+
TopN<CandidateInfo, CandidateInfoComparator> topn(num_samples);
105+
for (auto key : all_keys) {
106+
EmbeddingVectorProto embed;
107+
if (!knowledge_bank.Lookup(key, &embed).ok()) {
108+
continue;
109+
}
110+
if (knowledge_bank.embedding_dimension() != embed.value_size()) {
111+
return absl::InternalError(absl::StrCat(
112+
"Inconsistent embedding size (", embed.value_size(), " v.s. ",
113+
sample_context.activation().value_size(), ") for key: ", key));
114+
}
115+
float similarity = 0;
116+
switch (topk_config_.similarity_type()) {
117+
case DOT_PRODUCT:
118+
if (ComputeDotProduct(sample_context.activation(), embed,
119+
&similarity)) {
120+
CandidateInfo candidate_info(key, similarity);
121+
candidate_info.embed = std::move(embed);
122+
topn.push(std::move(candidate_info));
123+
}
124+
break;
125+
case COSINE:
126+
if (ComputeCosineSimilarity(sample_context.activation(), embed,
127+
&similarity)) {
128+
CandidateInfo candidate_info(key, similarity);
129+
candidate_info.embed = std::move(embed);
130+
topn.push(std::move(candidate_info));
131+
}
132+
break;
133+
default:
134+
LOG(FATAL) << "Shouldn't be here. Similarity type: "
135+
<< topk_config_.similarity_type();
136+
}
137+
}
138+
139+
// Processes results.
140+
results->clear();
141+
results->reserve(num_samples);
142+
std::unique_ptr<std::vector<CandidateInfo>> topn_results(topn.Extract());
143+
for (auto& candidate_info : *topn_results) {
144+
SampledResult sampled_result;
145+
TopkSamplingResult* result = sampled_result.mutable_topk_sampling_result();
146+
result->set_key(std::string(candidate_info.key));
147+
result->set_similarity(candidate_info.similarity);
148+
*(result->mutable_embedding()) = std::move(candidate_info.embed);
149+
results->push_back(std::move(sampled_result));
150+
}
151+
return absl::OkStatus();
152+
}
153+
154+
} // namespace candidate_sampling
155+
} // namespace carls

0 commit comments

Comments
 (0)