@@ -33,6 +33,7 @@ ABSL_FLAG(double, kbs_rpc_deadline_sec, 10,
3333namespace carls {
3434namespace {
3535
36+ using ::tensorflow::Tensor;
3637using ::tensorflow::tstring;
3738
3839#ifndef INTERNAL_DIE_IF_NULL
@@ -106,9 +107,8 @@ DynamicEmbeddingManager::DynamicEmbeddingManager(
106107 config_ (config),
107108 session_handle_(session_handle) {}
108109
109- absl::Status DynamicEmbeddingManager::Lookup (const tensorflow::Tensor& keys,
110- bool update,
111- tensorflow::Tensor* output) {
110+ absl::Status DynamicEmbeddingManager::Lookup (const Tensor& keys, bool update,
111+ Tensor* output) {
112112 CHECK (output != nullptr );
113113 if (!(keys.dims () == 1 || keys.dims () == 2 )) {
114114 return absl::InvalidArgumentError (absl::StrCat (
@@ -182,7 +182,7 @@ absl::Status DynamicEmbeddingManager::Lookup(const tensorflow::Tensor& keys,
182182}
183183
184184absl::Status DynamicEmbeddingManager::CheckInputForUpdate (
185- const tensorflow:: Tensor& keys, const tensorflow:: Tensor& values) {
185+ const Tensor& keys, const Tensor& values) {
186186 if (keys.NumElements () == 0 ) {
187187 return absl::InvalidArgumentError (" Input key is empty." );
188188 }
@@ -203,8 +203,8 @@ absl::Status DynamicEmbeddingManager::CheckInputForUpdate(
203203 return absl::OkStatus ();
204204}
205205
206- absl::Status DynamicEmbeddingManager::UpdateValues (
207- const tensorflow::Tensor& keys, const tensorflow:: Tensor& values) {
206+ absl::Status DynamicEmbeddingManager::UpdateValues (const Tensor& keys,
207+ const Tensor& values) {
208208 auto status = CheckInputForUpdate (keys, values);
209209 if (!status.ok ()) {
210210 return status;
@@ -260,8 +260,8 @@ absl::Status DynamicEmbeddingManager::LookupInternal(
260260 return ToAbslStatus (stub_->Lookup (&context, request, response));
261261}
262262
263- absl::Status DynamicEmbeddingManager::UpdateGradients (
264- const tensorflow::Tensor& keys, const tensorflow:: Tensor& grads) {
263+ absl::Status DynamicEmbeddingManager::UpdateGradients (const Tensor& keys,
264+ const Tensor& grads) {
265265 auto status = CheckInputForUpdate (keys, grads);
266266 if (!status.ok ()) {
267267 return status;
@@ -300,6 +300,152 @@ absl::Status DynamicEmbeddingManager::UpdateGradients(
300300 stub_->Update (&context, update_request, &update_response));
301301}
302302
303+ absl::Status DynamicEmbeddingManager::NegativeSamplingWithLogits (
304+ const Tensor& positive_keys, const Tensor& input_activations,
305+ const int num_samples, const bool update, Tensor* output_keys,
306+ Tensor* output_logits, Tensor* output_labels,
307+ Tensor* output_expected_counts, Tensor* output_masks,
308+ Tensor* output_embeddings) {
309+ RET_CHECK_TRUE (config_.embedding_dimension () > 0 )
310+ << " Invalid embedding dimension:" << config_.embedding_dimension ();
311+ RET_CHECK_TRUE (num_samples > 0 );
312+
313+ // Shape of input: [d1, d2, ..., inner_dim].
314+ const int dims = input_activations.dims ();
315+ const int inner_dim = input_activations.dim_size (dims - 1 );
316+ RET_CHECK_TRUE (inner_dim == config_.embedding_dimension ());
317+ const int batch_size =
318+ input_activations.NumElements () / config_.embedding_dimension ();
319+
320+ // Processes positive keys.
321+ SampleRequest sample_request;
322+ sample_request.set_session_handle (session_handle_);
323+ sample_request.set_num_samples (num_samples);
324+ sample_request.set_update (update);
325+ const auto pos_key_values = positive_keys.flat_inner_dims <tstring>();
326+ for (int b = 0 ; b < batch_size; ++b) {
327+ auto * sample_context = sample_request.add_sample_context ();
328+ for (int i = 0 ; i < positive_keys.dim_size (1 ); ++i) {
329+ if (!pos_key_values (b, i).empty ()) {
330+ sample_context->add_positive_key (std::string (pos_key_values (b, i)));
331+ }
332+ }
333+ }
334+
335+ // Calls the Sample RPC.
336+ grpc::ClientContext context;
337+ context.set_deadline (std::chrono::system_clock::now () +
338+ absl::ToChronoSeconds (absl::Seconds (
339+ absl::GetFlag (FLAGS_kbs_rpc_deadline_sec))));
340+ SampleResponse sample_response;
341+ RET_CHECK_OK (stub_->Sample (&context, sample_request, &sample_response));
342+ RET_CHECK_TRUE (sample_response.samples_size () == batch_size);
343+
344+ // Process sampled results.
345+ auto output_keys_values = output_keys->flat_inner_dims <tstring>();
346+ auto logits_values = output_logits->flat_inner_dims <float >();
347+ auto label_values = output_labels->flat_inner_dims <float >();
348+ auto expected_count_values = output_expected_counts->flat_inner_dims <float >();
349+ auto mask_values = output_masks->flat <float >();
350+ auto embedding_values = output_embeddings->flat_inner_dims <float , 3 >();
351+ auto input_values = input_activations.flat_inner_dims <float >();
352+ for (int b = 0 ; b < batch_size; ++b) {
353+ // Use auto& such that we can directly move some contents of samples into
354+ // the output for efficiency.
355+ auto & samples = *sample_response.mutable_samples (b);
356+
357+ // If no sample result is returned, set the default values for output
358+ // tensors.
359+ if (samples.sampled_result ().empty ()) {
360+ mask_values (b) = 0 .0f ;
361+ for (int i = 0 ; i < num_samples; ++i) {
362+ logits_values (b, i) = 0 .0f ;
363+ output_keys_values (b, i) = " " ;
364+ label_values (b, i) = 0 ;
365+ expected_count_values (b, i) = 1 ;
366+ for (int d = 0 ; d < config_.embedding_dimension (); ++d) {
367+ embedding_values (b, i, d) = 0 .0f ;
368+ }
369+ }
370+ continue ;
371+ }
372+ mask_values (b) = 1 .0f ;
373+
374+ // Processes the output tensors.
375+ RET_CHECK_TRUE (samples.sampled_result_size () == num_samples);
376+ RET_CHECK_TRUE (samples.sampled_result (0 ).has_negative_sampling_result ());
377+ for (int i = 0 ; i < samples.sampled_result_size (); ++i) {
378+ auto & result = *samples.mutable_sampled_result (i)
379+ ->mutable_negative_sampling_result ();
380+ const auto & embedding = result.embedding ();
381+ logits_values (b, i) = 0.0 ;
382+ label_values (b, i) = result.is_positive () ? 1.0 : 0.0 ;
383+ expected_count_values (b, i) = result.expected_count ();
384+ output_keys_values (b, i) = std::move (result.key ());
385+ float logit_value = 0 ; // Computes the dot product.
386+ for (int d = 0 ; d < config_.embedding_dimension (); ++d) {
387+ embedding_values (b, i, d) = embedding.value (d);
388+ // Computes the logits_values based on returned embedding values and
389+ // input activations.
390+ logit_value += input_values (b, d) * embedding.value (d);
391+ }
392+ logits_values (b, i) = logit_value;
393+ }
394+ }
395+
396+ return absl::OkStatus ();
397+ }
398+
399+ absl::Status DynamicEmbeddingManager::TopK (
400+ const tensorflow::Tensor& input_activations, const int k,
401+ tensorflow::Tensor* output_keys, tensorflow::Tensor* output_logits) {
402+ RET_CHECK_TRUE (config_.embedding_dimension () > 0 )
403+ << " Invalid embedding dimension:" << config_.embedding_dimension ();
404+ RET_CHECK_TRUE (k > 0 );
405+
406+ // Shape of input: batch_size x hidden_size.
407+ const int dims = input_activations.dims ();
408+ const int inner_dim = input_activations.dim_size (dims - 1 );
409+ RET_CHECK_TRUE (inner_dim == config_.embedding_dimension ());
410+ const int batch_size =
411+ input_activations.NumElements () / config_.embedding_dimension ();
412+
413+ // Processes SampleRequest.
414+ SampleRequest sample_request;
415+ sample_request.set_session_handle (session_handle_);
416+ sample_request.set_num_samples (k);
417+ auto activation_value = input_activations.flat_inner_dims <float >();
418+ for (int b = 0 ; b < batch_size; ++b) {
419+ auto * sample_context = sample_request.add_sample_context ();
420+ for (int i = 0 ; i < config_.embedding_dimension (); ++i) {
421+ sample_context->mutable_activation ()->add_value (activation_value (b, i));
422+ }
423+ }
424+
425+ // Calls the Sample RPC.
426+ grpc::ClientContext context;
427+ context.set_deadline (std::chrono::system_clock::now () +
428+ absl::ToChronoSeconds (absl::Seconds (
429+ absl::GetFlag (FLAGS_kbs_rpc_deadline_sec))));
430+ SampleResponse sample_response;
431+ RET_CHECK_OK (stub_->Sample (&context, sample_request, &sample_response));
432+ RET_CHECK_TRUE (sample_response.samples_size () == batch_size);
433+
434+ // Process topk results.
435+ auto output_keys_values = output_keys->flat_inner_dims <tstring>();
436+ auto logits_values = output_logits->flat_inner_dims <float >();
437+ for (int b = 0 ; b < batch_size; ++b) {
438+ const auto & samples = sample_response.samples (b);
439+ RET_CHECK_TRUE (samples.sampled_result_size () == k);
440+ for (int i = 0 ; i < k; ++i) {
441+ auto & result = samples.sampled_result (i).topk_sampling_result ();
442+ logits_values (b, i) = result.similarity ();
443+ output_keys_values (b, i) = std::move (result.key ());
444+ }
445+ }
446+ return absl::OkStatus ();
447+ }
448+
303449absl::Status DynamicEmbeddingManager::Export (const std::string& output_dir,
304450 std::string* exported_path) {
305451 CHECK (exported_path != nullptr );
0 commit comments