Skip to content

Commit eeeb95b

Browse files
Neural-Link Teamtensorflow-copybara
authored andcommitted
Enable candidate sampling in DynamicEmbeddingManager.
PiperOrigin-RevId: 367746852
1 parent 0544f9e commit eeeb95b

File tree

7 files changed

+504
-13
lines changed

7 files changed

+504
-13
lines changed

research/carls/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ cc_library(
158158
"//research/carls/knowledge_bank",
159159
"//research/carls/knowledge_bank:in_proto_knowledge_bank",
160160
"@com_github_grpc_grpc//:grpc++",
161+
"@com_google_absl//absl/container:flat_hash_set",
161162
"@com_google_absl//absl/strings",
162163
"@com_google_absl//absl/synchronization",
163164
],
@@ -245,8 +246,11 @@ cc_test(
245246
":dynamic_embedding_manager",
246247
":kbs_server_helper_lib",
247248
"//research/carls/base:proto_helper",
249+
"//research/carls/candidate_sampling:candidate_sampler_config_cc_proto",
250+
"//research/carls/testing:test_helper",
248251
"@com_github_grpc_grpc//:grpc++",
249252
"@com_google_googletest//:gtest_main",
253+
"@tensorflow_solib//:framework_lib",
250254
],
251255
)
252256

research/carls/dynamic_embedding_manager.cc

Lines changed: 154 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ ABSL_FLAG(double, kbs_rpc_deadline_sec, 10,
3333
namespace carls {
3434
namespace {
3535

36+
using ::tensorflow::Tensor;
3637
using ::tensorflow::tstring;
3738

3839
#ifndef INTERNAL_DIE_IF_NULL
@@ -106,9 +107,8 @@ DynamicEmbeddingManager::DynamicEmbeddingManager(
106107
config_(config),
107108
session_handle_(session_handle) {}
108109

109-
absl::Status DynamicEmbeddingManager::Lookup(const tensorflow::Tensor& keys,
110-
bool update,
111-
tensorflow::Tensor* output) {
110+
absl::Status DynamicEmbeddingManager::Lookup(const Tensor& keys, bool update,
111+
Tensor* output) {
112112
CHECK(output != nullptr);
113113
if (!(keys.dims() == 1 || keys.dims() == 2)) {
114114
return absl::InvalidArgumentError(absl::StrCat(
@@ -182,7 +182,7 @@ absl::Status DynamicEmbeddingManager::Lookup(const tensorflow::Tensor& keys,
182182
}
183183

184184
absl::Status DynamicEmbeddingManager::CheckInputForUpdate(
185-
const tensorflow::Tensor& keys, const tensorflow::Tensor& values) {
185+
const Tensor& keys, const Tensor& values) {
186186
if (keys.NumElements() == 0) {
187187
return absl::InvalidArgumentError("Input key is empty.");
188188
}
@@ -203,8 +203,8 @@ absl::Status DynamicEmbeddingManager::CheckInputForUpdate(
203203
return absl::OkStatus();
204204
}
205205

206-
absl::Status DynamicEmbeddingManager::UpdateValues(
207-
const tensorflow::Tensor& keys, const tensorflow::Tensor& values) {
206+
absl::Status DynamicEmbeddingManager::UpdateValues(const Tensor& keys,
207+
const Tensor& values) {
208208
auto status = CheckInputForUpdate(keys, values);
209209
if (!status.ok()) {
210210
return status;
@@ -260,8 +260,8 @@ absl::Status DynamicEmbeddingManager::LookupInternal(
260260
return ToAbslStatus(stub_->Lookup(&context, request, response));
261261
}
262262

263-
absl::Status DynamicEmbeddingManager::UpdateGradients(
264-
const tensorflow::Tensor& keys, const tensorflow::Tensor& grads) {
263+
absl::Status DynamicEmbeddingManager::UpdateGradients(const Tensor& keys,
264+
const Tensor& grads) {
265265
auto status = CheckInputForUpdate(keys, grads);
266266
if (!status.ok()) {
267267
return status;
@@ -300,6 +300,152 @@ absl::Status DynamicEmbeddingManager::UpdateGradients(
300300
stub_->Update(&context, update_request, &update_response));
301301
}
302302

303+
absl::Status DynamicEmbeddingManager::NegativeSamplingWithLogits(
304+
const Tensor& positive_keys, const Tensor& input_activations,
305+
const int num_samples, const bool update, Tensor* output_keys,
306+
Tensor* output_logits, Tensor* output_labels,
307+
Tensor* output_expected_counts, Tensor* output_masks,
308+
Tensor* output_embeddings) {
309+
RET_CHECK_TRUE(config_.embedding_dimension() > 0)
310+
<< "Invalid embedding dimension:" << config_.embedding_dimension();
311+
RET_CHECK_TRUE(num_samples > 0);
312+
313+
// Shape of input: [d1, d2, ..., inner_dim].
314+
const int dims = input_activations.dims();
315+
const int inner_dim = input_activations.dim_size(dims - 1);
316+
RET_CHECK_TRUE(inner_dim == config_.embedding_dimension());
317+
const int batch_size =
318+
input_activations.NumElements() / config_.embedding_dimension();
319+
320+
// Processes positive keys.
321+
SampleRequest sample_request;
322+
sample_request.set_session_handle(session_handle_);
323+
sample_request.set_num_samples(num_samples);
324+
sample_request.set_update(update);
325+
const auto pos_key_values = positive_keys.flat_inner_dims<tstring>();
326+
for (int b = 0; b < batch_size; ++b) {
327+
auto* sample_context = sample_request.add_sample_context();
328+
for (int i = 0; i < positive_keys.dim_size(1); ++i) {
329+
if (!pos_key_values(b, i).empty()) {
330+
sample_context->add_positive_key(std::string(pos_key_values(b, i)));
331+
}
332+
}
333+
}
334+
335+
// Calls the Sample RPC.
336+
grpc::ClientContext context;
337+
context.set_deadline(std::chrono::system_clock::now() +
338+
absl::ToChronoSeconds(absl::Seconds(
339+
absl::GetFlag(FLAGS_kbs_rpc_deadline_sec))));
340+
SampleResponse sample_response;
341+
RET_CHECK_OK(stub_->Sample(&context, sample_request, &sample_response));
342+
RET_CHECK_TRUE(sample_response.samples_size() == batch_size);
343+
344+
// Process sampled results.
345+
auto output_keys_values = output_keys->flat_inner_dims<tstring>();
346+
auto logits_values = output_logits->flat_inner_dims<float>();
347+
auto label_values = output_labels->flat_inner_dims<float>();
348+
auto expected_count_values = output_expected_counts->flat_inner_dims<float>();
349+
auto mask_values = output_masks->flat<float>();
350+
auto embedding_values = output_embeddings->flat_inner_dims<float, 3>();
351+
auto input_values = input_activations.flat_inner_dims<float>();
352+
for (int b = 0; b < batch_size; ++b) {
353+
// Use auto& such that we can directly move some contents of samples into
354+
// the output for efficiency.
355+
auto& samples = *sample_response.mutable_samples(b);
356+
357+
// If no sample result is returned, set the default values for output
358+
// tensors.
359+
if (samples.sampled_result().empty()) {
360+
mask_values(b) = 0.0f;
361+
for (int i = 0; i < num_samples; ++i) {
362+
logits_values(b, i) = 0.0f;
363+
output_keys_values(b, i) = "";
364+
label_values(b, i) = 0;
365+
expected_count_values(b, i) = 1;
366+
for (int d = 0; d < config_.embedding_dimension(); ++d) {
367+
embedding_values(b, i, d) = 0.0f;
368+
}
369+
}
370+
continue;
371+
}
372+
mask_values(b) = 1.0f;
373+
374+
// Processes the output tensors.
375+
RET_CHECK_TRUE(samples.sampled_result_size() == num_samples);
376+
RET_CHECK_TRUE(samples.sampled_result(0).has_negative_sampling_result());
377+
for (int i = 0; i < samples.sampled_result_size(); ++i) {
378+
auto& result = *samples.mutable_sampled_result(i)
379+
->mutable_negative_sampling_result();
380+
const auto& embedding = result.embedding();
381+
logits_values(b, i) = 0.0;
382+
label_values(b, i) = result.is_positive() ? 1.0 : 0.0;
383+
expected_count_values(b, i) = result.expected_count();
384+
output_keys_values(b, i) = std::move(result.key());
385+
float logit_value = 0; // Computes the dot product.
386+
for (int d = 0; d < config_.embedding_dimension(); ++d) {
387+
embedding_values(b, i, d) = embedding.value(d);
388+
// Computes the logits_values based on returned embedding values and
389+
// input activations.
390+
logit_value += input_values(b, d) * embedding.value(d);
391+
}
392+
logits_values(b, i) = logit_value;
393+
}
394+
}
395+
396+
return absl::OkStatus();
397+
}
398+
399+
absl::Status DynamicEmbeddingManager::TopK(
400+
const tensorflow::Tensor& input_activations, const int k,
401+
tensorflow::Tensor* output_keys, tensorflow::Tensor* output_logits) {
402+
RET_CHECK_TRUE(config_.embedding_dimension() > 0)
403+
<< "Invalid embedding dimension:" << config_.embedding_dimension();
404+
RET_CHECK_TRUE(k > 0);
405+
406+
// Shape of input: batch_size x hidden_size.
407+
const int dims = input_activations.dims();
408+
const int inner_dim = input_activations.dim_size(dims - 1);
409+
RET_CHECK_TRUE(inner_dim == config_.embedding_dimension());
410+
const int batch_size =
411+
input_activations.NumElements() / config_.embedding_dimension();
412+
413+
// Processes SampleRequest.
414+
SampleRequest sample_request;
415+
sample_request.set_session_handle(session_handle_);
416+
sample_request.set_num_samples(k);
417+
auto activation_value = input_activations.flat_inner_dims<float>();
418+
for (int b = 0; b < batch_size; ++b) {
419+
auto* sample_context = sample_request.add_sample_context();
420+
for (int i = 0; i < config_.embedding_dimension(); ++i) {
421+
sample_context->mutable_activation()->add_value(activation_value(b, i));
422+
}
423+
}
424+
425+
// Calls the Sample RPC.
426+
grpc::ClientContext context;
427+
context.set_deadline(std::chrono::system_clock::now() +
428+
absl::ToChronoSeconds(absl::Seconds(
429+
absl::GetFlag(FLAGS_kbs_rpc_deadline_sec))));
430+
SampleResponse sample_response;
431+
RET_CHECK_OK(stub_->Sample(&context, sample_request, &sample_response));
432+
RET_CHECK_TRUE(sample_response.samples_size() == batch_size);
433+
434+
// Process topk results.
435+
auto output_keys_values = output_keys->flat_inner_dims<tstring>();
436+
auto logits_values = output_logits->flat_inner_dims<float>();
437+
for (int b = 0; b < batch_size; ++b) {
438+
const auto& samples = sample_response.samples(b);
439+
RET_CHECK_TRUE(samples.sampled_result_size() == k);
440+
for (int i = 0; i < k; ++i) {
441+
auto& result = samples.sampled_result(i).topk_sampling_result();
442+
logits_values(b, i) = result.similarity();
443+
output_keys_values(b, i) = std::move(result.key());
444+
}
445+
}
446+
return absl::OkStatus();
447+
}
448+
303449
absl::Status DynamicEmbeddingManager::Export(const std::string& output_dir,
304450
std::string* exported_path) {
305451
CHECK(exported_path != nullptr);

research/carls/dynamic_embedding_manager.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,44 @@ class DynamicEmbeddingManager {
6161
// Returns DynamicEmbeddingConfig.
6262
const DynamicEmbeddingConfig& config() { return config_; }
6363

64+
// Samples negative keys from given positive keys and compute the dot products
65+
// between the embeddings of the positive/negative keys and the input
66+
// activations.
67+
//
68+
// If update = true, new embeddings are dynamically allocated for new
69+
// positive keys, which is often used in training.
70+
//
71+
// Note that for a logit layer with activation x in the last layer, one needs
72+
// to append an extra 1 to the input activations to obtain wx + b, where [w,
73+
// b] is the embedding of a particular output key.
74+
//
75+
// The `output_labels` indicates if the corresponding `output_keys` is a
76+
// positive or negative sample, and the `output_expected_counts` represents
77+
// the sampling probability. Please refer to
78+
// carls.candidate_sampling.NegativeSamplingResult for details.
79+
//
80+
// `output_mask` indicates whether `positive_keys` of an entry in the input
81+
// batch are all invalid (empty).
82+
//
83+
// `output_embedding` returns the embeddings of the sampled keys. It should
84+
// be allocated as [batch_size, num_samples, embed_dim]. This is needed for
85+
// computing the gradients w.r.t. the input_activations.
86+
absl::Status NegativeSamplingWithLogits(
87+
const tensorflow::Tensor& positive_keys,
88+
const tensorflow::Tensor& input_activations, int num_samples, bool update,
89+
tensorflow::Tensor* output_keys, tensorflow::Tensor* output_logits,
90+
tensorflow::Tensor* output_labels,
91+
tensorflow::Tensor* output_expected_counts,
92+
tensorflow::Tensor* output_masks, tensorflow::Tensor* output_embeddings);
93+
94+
// Return top k closest embeddings to each of the input activations.
95+
// Note that for a logit layer with activation x, one need to append an extra
96+
// 1 to the input activations to obtain wx + b, where [w, b] is the embedding
97+
// of a particular output key.
98+
absl::Status TopK(const tensorflow::Tensor& input_activations, int k,
99+
tensorflow::Tensor* output_keys,
100+
tensorflow::Tensor* output_logits);
101+
64102
// Calls the KnowledgeBankService::Export RPC.
65103
absl::Status Export(const std::string& output_dir,
66104
std::string* exported_path);

0 commit comments

Comments
 (0)