Skip to content

Commit 7ea3afa

Browse files
Neural-Link Teamtensorflow-copybara
authored andcommitted
Enable Candidate Sampling Part II: add Sample RPC
PiperOrigin-RevId: 366610296
1 parent 5312f11 commit 7ea3afa

File tree

6 files changed

+258
-51
lines changed

6 files changed

+258
-51
lines changed

research/carls/BUILD

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ carls_cc_proto_library(
6666
name = "dynamic_embedding_config_cc_proto",
6767
srcs = ["dynamic_embedding_config.proto"],
6868
deps = [
69+
"//research/carls/candidate_sampling:candidate_sampler_config_cc_proto",
6970
"//research/carls/gradient_descent:gradient_descent_config_cc_proto",
7071
"//research/carls/knowledge_bank:knowledge_bank_config_cc_proto",
7172
],
@@ -76,6 +77,7 @@ carls_py_proto_library(
7677
srcs = ["dynamic_embedding_config.proto"],
7778
deps = [
7879
":dynamic_embedding_config_cc_proto",
80+
"//research/carls/candidate_sampling:candidate_sampler_config_py_pb2",
7981
"//research/carls/gradient_descent:gradient_descent_config_py_pb2",
8082
"//research/carls/knowledge_bank:knowledge_bank_config_py_pb2",
8183
],
@@ -87,6 +89,7 @@ carls_cc_proto_library(
8789
deps = [
8890
":dynamic_embedding_config_cc_proto",
8991
":embedding_cc_proto",
92+
"//research/carls/candidate_sampling:candidate_sampler_config_cc_proto",
9093
],
9194
)
9295

@@ -148,6 +151,9 @@ cc_library(
148151
hdrs = ["knowledge_bank_grpc_service.h"],
149152
deps = [
150153
":knowledge_bank_service_cc_grpc_proto",
154+
"//research/carls/candidate_sampling:brute_force_topk_sampler",
155+
"//research/carls/candidate_sampling:candidate_sampler",
156+
"//research/carls/candidate_sampling:log_uniform_sampler",
151157
"//research/carls/gradient_descent:gradient_descent_optimizer",
152158
"//research/carls/knowledge_bank",
153159
"//research/carls/knowledge_bank:in_proto_knowledge_bank",

research/carls/dynamic_embedding_config.proto

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ syntax = "proto3";
22

33
package carls;
44

5+
import "research/carls/candidate_sampling/candidate_sampler_config.proto";
56
import "research/carls/gradient_descent/gradient_descent_config.proto";
67
import "research/carls/knowledge_bank/knowledge_bank_config.proto";
78

@@ -14,4 +15,6 @@ message DynamicEmbeddingConfig {
1415
KnowledgeBankConfig knowledge_bank_config = 2;
1516

1617
GradientDescentConfig gradient_descent_config = 3;
18+
19+
candidate_sampling.CandidateSamplerConfig candidate_sampler_config = 4;
1720
}

research/carls/knowledge_bank_grpc_service.cc

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,37 @@ Status KnowledgeBankGrpcServiceImpl::Update(grpc::ServerContext* context,
176176
return Status::OK;
177177
}
178178

179+
grpc::Status KnowledgeBankGrpcServiceImpl::Sample(grpc::ServerContext* context,
180+
const SampleRequest* request,
181+
SampleResponse* response) {
182+
if (request->session_handle().empty()) {
183+
return Status(StatusCode::INVALID_ARGUMENT, "session_handle is empty.");
184+
}
185+
if (request->sample_context().empty()) {
186+
return Status(StatusCode::INVALID_ARGUMENT, "No sample context.");
187+
}
188+
const auto status = StartSessionIfNecessary(request->session_handle());
189+
if (!status.ok()) {
190+
return status;
191+
}
192+
absl::MutexLock lock(&map_mu_);
193+
const auto& knowledge_bank = *kb_map_[request->session_handle()];
194+
195+
for (const auto& sample_context : request->sample_context()) {
196+
std::vector<candidate_sampling::SampledResult> results;
197+
auto status = cs_map_[request->session_handle()]->Sample(
198+
knowledge_bank, sample_context, request->num_samples(), &results);
199+
if (!status.ok()) {
200+
return ToGrpcStatus(status);
201+
}
202+
auto* samples = response->add_samples();
203+
for (auto& result : results) {
204+
*samples->add_sampled_result() = std::move(result);
205+
}
206+
}
207+
return Status::OK;
208+
}
209+
179210
Status KnowledgeBankGrpcServiceImpl::Export(grpc::ServerContext* context,
180211
const ExportRequest* request,
181212
ExportResponse* response) {
@@ -247,6 +278,15 @@ Status KnowledgeBankGrpcServiceImpl::StartSessionIfNecessary(
247278
}
248279
gd_map_[session_handle] = std::move(optimizer);
249280
}
281+
if (request.config().has_candidate_sampler_config() &&
282+
!cs_map_.contains(session_handle)) {
283+
auto sampler = candidate_sampling::SamplerFactory::Make(
284+
request.config().candidate_sampler_config());
285+
if (sampler == nullptr) {
286+
return Status(StatusCode::INTERNAL, "Creating CandidateSampler failed.");
287+
}
288+
cs_map_[session_handle] = std::move(sampler);
289+
}
250290
return Status::OK;
251291
}
252292

research/carls/knowledge_bank_grpc_service.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020

2121
#include "grpcpp/support/status.h" // net
2222
#include "absl/synchronization/mutex.h"
23+
#include "research/carls/candidate_sampling/candidate_sampler.h"
2324
#include "research/carls/gradient_descent/gradient_descent_optimizer.h"
2425
#include "research/carls/knowledge_bank/knowledge_bank.h"
2526
#include "research/carls/knowledge_bank_service.grpc.pb.h"
@@ -53,6 +54,11 @@ class KnowledgeBankGrpcServiceImpl final
5354
const UpdateRequest* request,
5455
UpdateResponse* response) override;
5556

57+
// Implements the Sample method of KnowledgeBankService.
58+
grpc::Status Sample(grpc::ServerContext* context,
59+
const SampleRequest* request,
60+
SampleResponse* response) override;
61+
5662
// Implements the Export method of KnowledgeBankService.
5763
grpc::Status Export(grpc::ServerContext* context,
5864
const ExportRequest* request,
@@ -77,6 +83,10 @@ class KnowledgeBankGrpcServiceImpl final
7783
// Maps from session_handle to GradientDescentOptimizer.
7884
absl::node_hash_map<std::string, std::unique_ptr<GradientDescentOptimizer>>
7985
gd_map_;
86+
// Maps from session_handle to CandidateSampler.
87+
absl::node_hash_map<std::string,
88+
std::unique_ptr<candidate_sampling::CandidateSampler>>
89+
cs_map_;
8090
};
8191

8292
} // namespace carls

0 commit comments

Comments
 (0)