
Commit 4a574b8

Neural-Link Team authored and tensorflow-copybara committed

Add candidate sampling ops

PiperOrigin-RevId: 368953809
1 parent debdd98 commit 4a574b8

18 files changed: +917 -94 lines

research/carls/BUILD

Lines changed: 30 additions & 1 deletion
@@ -276,9 +276,38 @@ py_test(
     deps = [
         ":context",
         ":dynamic_embedding_ops_py",
+        "//research/carls/testing:test_util",
         "@com_google_absl_py//absl/testing:parameterized",
+    ],
+)
+
+py_library(
+    name = "candidate_sampling_ops_py",
+    srcs = ["candidate_sampling_ops.py"],
+    srcs_version = "PY3",
+    visibility = ["//visibility:public"],
+    deps = [
+        ":context",
+        ":dynamic_embedding_config_py_pb2",
+        "//research/carls/kernels:gen_dynamic_embedding_ops_py",
+        "//research/carls/kernels:gen_sampled_logits_ops_py",
+        "//research/carls/kernels:gen_topk_ops_py",
+    ],
+)
+
+py_test(
+    name = "candidate_sampling_ops_test",
+    size = "small",
+    srcs = ["candidate_sampling_ops_test.py"],
+    python_version = "PY3",
+    srcs_version = "PY3",
+    deps = [
+        ":candidate_sampling_ops_py",
+        ":context",
+        ":dynamic_embedding_ops_py",
+        "//research/carls/candidate_sampling:candidate_sampler_config_builder_py",
         "//research/carls/testing:test_util",
-        # package tensorflow
+        "@com_google_absl_py//absl/testing:parameterized",
     ],
 )

research/carls/candidate_sampling_ops.py

Lines changed: 253 additions & 0 deletions

@@ -0,0 +1,253 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Candidate sampling related ops."""

import typing

from research.carls import context
from research.carls import dynamic_embedding_config_pb2 as de_config_pb2
from research.carls.kernels import gen_dynamic_embedding_ops as de_ops
from research.carls.kernels import gen_sampled_logits_ops
from research.carls.kernels import gen_topk_ops as gen_topk_op
import tensorflow as tf

def top_k(inputs: tf.Tensor,
          k: int,
          de_config: de_config_pb2.DynamicEmbeddingConfig,
          var_name: typing.Text,
          service_address: typing.Text = "",
          timeout_ms: int = -1):
  """Computes logits for the top k closest embeddings to the inputs.

  Args:
    inputs: A float `Tensor` of shape `[batch_size, dim]` representing the
      forward activations of the input network.
    k: An `int` denoting the number of returned keys.
    de_config: A DynamicEmbeddingConfig for configuring the dynamic embedding.
    var_name: A unique name for the operation.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    keys: A string `Tensor` of shape `[batch_size, k]` representing the top k
      keys relative to the input.
    logits: A float `Tensor` of shape `[batch_size, k]` representing the logits
      for the returned keys.

  Raises:
    ValueError: If k is not greater than zero.

  Note: The (keys, logits) pair returned here should not be used for training
  as they only represent biased sampling. Instead, use sampled_softmax_loss()
  for training.
  """
  if not var_name:
    raise ValueError("Must specify a valid var_name.")
  if k <= 0:
    raise ValueError("k must be greater than zero, got %d" % k)

  context.add_to_collection(var_name, de_config)
  resource = de_ops.dynamic_embedding_manager_resource(
      de_config.SerializeToString(), var_name, service_address, timeout_ms)
  return gen_topk_op.topk_lookup(inputs, k, resource)

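A minimal usage sketch for top_k, not part of this commit: the config fields and the service address below are illustrative assumptions, and a knowledge bank service with a matching embedding dimension is presumed to be running.

# Illustrative only; DynamicEmbeddingConfig fields and the address are assumed.
config = de_config_pb2.DynamicEmbeddingConfig(embedding_dimension=4)
activations = tf.random.uniform([8, 4])  # [batch_size, dim]
keys, logits = top_k(activations, k=5, de_config=config,
                     var_name="top_k_example",
                     service_address="localhost:10000")
# keys: string Tensor of shape [8, 5]; logits: float Tensor of shape [8, 5].
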
def sampled_softmax_loss(positive_keys: tf.Tensor,
                         inputs: tf.Tensor,
                         num_samples: int,
                         de_config: de_config_pb2.DynamicEmbeddingConfig,
                         var_name: typing.Text,
                         service_address: typing.Text = "",
                         timeout_ms: int = -1):
  """Computes sampled softmax loss from given input activations.

  Args:
    positive_keys: A string `Tensor` of shape `[batch_size, None]` representing
      input positive keys.
    inputs: A float `Tensor` of shape `[batch_size, dim]`, representing the
      forward activations of the input network.
    num_samples: An int denoting the number of returned positive and negative
      samples.
    de_config: A DynamicEmbeddingConfig for configuring the dynamic embedding.
    var_name: A unique name for the operation.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    A float `Tensor` representing the sampled softmax loss.
  """
  logits, labels, _, mask, _ = compute_sampled_logits(positive_keys, inputs,
                                                      num_samples, de_config,
                                                      var_name, service_address,
                                                      timeout_ms)
  tiled_norm = tf.tile(
      tf.maximum(tf.reduce_sum(labels, -1, keepdims=True), 1),
      [1, labels.get_shape()[-1]])
  labels /= tiled_norm
  return tf.reduce_sum(
      tf.nn.softmax_cross_entropy_with_logits_v2(
          labels=labels, logits=logits)) / tf.reduce_sum(mask)

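A sketch of calling sampled_softmax_loss, again illustrative and not from the commit; it assumes the same hypothetical config and a running service as in the top_k sketch above.

# Each row of positive_keys holds that example's positive label keys.
config = de_config_pb2.DynamicEmbeddingConfig(embedding_dimension=4)  # assumed
positive_keys = tf.constant([["cat"], ["dog"], ["bird"], ["fish"]])
activations = tf.random.uniform([4, 4])  # [batch_size, dim]
loss = sampled_softmax_loss(positive_keys, activations, num_samples=16,
                            de_config=config, var_name="softmax_loss_example")
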
def sampled_sigmoid_loss(positive_keys: tf.Tensor,
                         inputs: tf.Tensor,
                         num_samples: int,
                         de_config: de_config_pb2.DynamicEmbeddingConfig,
                         var_name: typing.Text,
                         service_address: typing.Text = "",
                         timeout_ms: int = -1):
  """Computes sampled sigmoid loss from given input activations.

  Args:
    positive_keys: A string `Tensor` of shape `[batch_size, None]` representing
      input positive keys.
    inputs: A float `Tensor` of shape `[batch_size, dim]`, representing the
      forward activations of the input network.
    num_samples: An int denoting the number of returned positive and negative
      samples.
    de_config: A DynamicEmbeddingConfig for configuring the dynamic embedding.
    var_name: A unique name for the operation.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    A float `Tensor` representing the sampled sigmoid loss.
  """
  logits, labels, _, mask, _ = compute_sampled_logits(positive_keys, inputs,
                                                      num_samples, de_config,
                                                      var_name, service_address,
                                                      timeout_ms)
  tiled_norm = tf.tile(
      tf.maximum(tf.reduce_sum(labels, -1, keepdims=True), 1),
      [1, labels.get_shape()[-1]])
  labels /= tiled_norm
  reduced_sum = tf.reduce_sum(
      tf.nn.sigmoid_cross_entropy_with_logits(
          labels=labels, logits=logits)) / tf.reduce_sum(mask)
  return reduced_sum / num_samples

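The tiled_norm step in both losses rescales each row of labels so it sums to one whenever any positive is present. A small self-contained illustration of that normalization:

# With 4 samples of which two are positives, each positive label becomes 0.5.
labels = tf.constant([[1.0, 0.0, 1.0, 0.0]])
tiled_norm = tf.tile(
    tf.maximum(tf.reduce_sum(labels, -1, keepdims=True), 1),
    [1, labels.get_shape()[-1]])
print((labels / tiled_norm).numpy())  # [[0.5 0.  0.5 0. ]]
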
def compute_sampled_logits(positive_keys,
                           inputs,
                           num_samples: int,
                           de_config: de_config_pb2.DynamicEmbeddingConfig,
                           var_name: typing.Text,
                           service_address: typing.Text = "",
                           timeout_ms: int = -1):
  """Computes sampled logits from given positive labels.

  Args:
    positive_keys: A string `Tensor` of shape `[batch_size, None]` representing
      input positive keys.
    inputs: A float `Tensor` of shape `[batch_size, dim]` representing the
      forward activations of the input network.
    num_samples: An int denoting the number of returned positive and negative
      samples.
    de_config: A DynamicEmbeddingConfig for configuring the dynamic embedding.
    var_name: A unique name for the operation.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    logits: A float `Tensor` of shape `[batch_size, num_samples]` representing
      the logits for sampled labels.
    labels: A float `Tensor` of shape `[batch_size, num_samples]` with values
      in {0, 1} indicating if the sample is positive or negative.
    keys: A string `Tensor` of shape `[batch_size, num_samples]` representing
      the keys for each sample.
    mask: A float `Tensor` of shape `[batch_size]` representing the 0/1 mask
      of each batch entry. For example, if all keys in positive_keys[i] are
      empty, mask[i] = 0; otherwise mask[i] = 1.
    weights: A float `Tensor` representing the embeddings of the sampled keys.

  Raises:
    ValueError: If var_name is not specified.
    TypeError: If de_config is not an instance of DynamicEmbeddingConfig.
  """
  if not var_name:
    raise ValueError("Must specify a valid name, got %s" % var_name)
  if num_samples < 1:
    raise ValueError("Invalid num_samples: %d" % num_samples)

  context.add_to_collection(var_name, de_config)
  resource = de_ops.dynamic_embedding_manager_resource(
      de_config.SerializeToString(), var_name, service_address, timeout_ms)

  # Create a dummy variable so that the gradients can be passed in.
  grad_placeholder = tf.Variable(0.0)

  keys, labels, expected_counts, mask, weights = (
      gen_sampled_logits_ops.sampled_logits_lookup(positive_keys, inputs,
                                                   num_samples,
                                                   grad_placeholder, resource))

  # Compute sampled logits.
  # Shape of weights: [d1, d2, ..., dn-1, num_samples, embed_dim]
  # Shape of inputs:  [d1, d2, ..., dn-1, embed_dim]
  # Shape of output logits: [d1, d2, ..., dn-1, num_samples]

  # [d1, d2, ..., dn-1, embed_dim] -> [d1, d2, ..., dn-1, 1, embed_dim]
  tiled_inputs = tf.expand_dims(inputs, axis=-2)
  # [d1, ..., dn-1, 1, embed_dim] -> [d1, ..., dn-1, num_samples, embed_dim]
  multiples = [1] * (inputs.ndim + 1)
  multiples[-2] = num_samples
  tiled_inputs = tf.tile(tiled_inputs, multiples)
  # [d1, ..., dn-1, num_samples, embed_dim] -> [d1, ..., dn-1, num_samples]
  logits = tf.reduce_sum(weights * tiled_inputs, -1)
  # Subtract log(expected_counts) to obtain the sampled logits.
  logits -= tf.math.log(expected_counts)

  return logits, labels, keys, mask, weights

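The expand_dims/tile shape algebra above can be checked in isolation with plain TensorFlow (illustrative shapes, not from the commit):

batch, num_samples, dim = 4, 16, 8
inputs = tf.random.uniform([batch, dim])
weights = tf.random.uniform([batch, num_samples, dim])
tiled = tf.tile(tf.expand_dims(inputs, -2), [1, num_samples, 1])
logits = tf.reduce_sum(weights * tiled, -1)  # shape [4, 16]
# The final correction then subtracts log(expected_counts): a key sampled with
# expected count 4 has its logit reduced by ln(4) ~= 1.386.
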
@tf.RegisterGradient("SampledLogitsLookup")
def _sampled_logits_lookup_grad(op, keys_grad, labels_grad,
                                expected_counts_grad, mask_grad, weights_grad):
  """Computes the gradients for SampledLogitsLookup.

  We use the gradients w.r.t. the weights output of sampled_logits_lookup() to
  update the embeddings/weights of the sampled keys. Gradients for the inputs
  of sampled_logits_lookup() must be provided, but none of them needs to be
  back-propagated, so we set all of them to zeros.

  Args:
    op: The SampledLogitsLookup op.
    keys_grad: The tensor representing the gradient w.r.t. the keys output.
    labels_grad: The tensor representing the gradient w.r.t. the labels output.
    expected_counts_grad: The tensor representing the gradient w.r.t. the
      expected_counts output.
    mask_grad: The tensor representing the gradient w.r.t. the mask output.
    weights_grad: The tensor representing the gradient w.r.t. the weights
      output.

  Returns:
    The gradients w.r.t. the inputs.
  """
  del keys_grad, labels_grad, expected_counts_grad, mask_grad  # Unused.

  pos_keys_grad, num_samples_grad, dummy_variable_grad, resource_grad = (
      gen_sampled_logits_ops.sampled_logits_lookup_grad(
          keys=op.outputs[0],
          weight_gradients=weights_grad,
          handle=op.inputs[4]))
  # Gradient for the input activations.
  inputs_grad = tf.zeros_like(op.inputs[1])
  return (pos_keys_grad, inputs_grad, num_samples_grad, dummy_variable_grad,
          resource_grad)
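
A training-loop sketch tying the pieces together, illustrative only: it assumes a reachable dynamic embedding service, a valid config, and hypothetical layer/variable names not from this commit. The input network receives gradients through the weights * tiled_inputs product in compute_sampled_logits(), while the sampled keys' embeddings are updated through the gradient registered above.

config = de_config_pb2.DynamicEmbeddingConfig(embedding_dimension=4)  # assumed
dense = tf.keras.layers.Dense(4)
optimizer = tf.keras.optimizers.SGD(0.1)
features = tf.random.uniform([4, 32])
positive_keys = tf.constant([["a"], ["b"], ["c"], ["d"]])
with tf.GradientTape() as tape:
  loss = sampled_softmax_loss(positive_keys, dense(features), num_samples=16,
                              de_config=config, var_name="train_example")
grads = tape.gradient(loss, dense.trainable_variables)
optimizer.apply_gradients(zip(grads, dense.trainable_variables))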
