# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Candidate sampling related ops."""

import typing

from research.carls import context
from research.carls import dynamic_embedding_config_pb2 as de_config_pb2
from research.carls.kernels import gen_dynamic_embedding_ops as de_ops
from research.carls.kernels import gen_sampled_logits_ops
from research.carls.kernels import gen_topk_ops as gen_topk_op
import tensorflow as tf

def top_k(inputs: tf.Tensor,
          k: int,
          de_config: de_config_pb2.DynamicEmbeddingConfig,
          var_name: typing.Text,
          service_address: typing.Text = "",
          timeout_ms: int = -1):
  """Computes logits for the top k closest embeddings to the inputs.

  Args:
    inputs: A float `Tensor` of shape `[batch_size, dim]` representing the
      forward activations of the input network.
    k: An `int` denoting the number of returned keys.
    de_config: A DynamicEmbeddingConfig for configuring the dynamic embedding.
    var_name: A unique name for the operation.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    keys: A string `Tensor` of shape `[batch_size, k]` representing the top k
      keys relative to the input.
    logits: A float `Tensor` of shape `[batch_size, k]` representing the logits
      for the returned keys.

  Raises:
    ValueError: If `var_name` is empty or `k` is not greater than zero.

  Note: the (keys, logits) pair returned here should not be used for training
  as they only represent biased sampling. Instead, use sampled_softmax_loss()
  for training.
  """
  if not var_name:
    raise ValueError("Must specify a valid var_name.")
  if k <= 0:
    raise ValueError("k must be greater than zero, got %d" % k)

  context.add_to_collection(var_name, de_config)
  resource = de_ops.dynamic_embedding_manager_resource(
      de_config.SerializeToString(), var_name, service_address, timeout_ms)
  return gen_topk_op.topk_lookup(inputs, k, resource)

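# Example (a minimal sketch, not part of the original module): calling top_k()
# against a running knowledge bank service. `config` and `address` are
# placeholder names, not symbols defined in this file.
#
#   activations = tf.random.uniform([4, 64])  # [batch_size, dim]
#   keys, logits = top_k(activations, k=10, de_config=config,
#                        var_name="candidate_embedding",
#                        service_address=address)
#   # keys and logits both have shape [4, 10].

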
def sampled_softmax_loss(positive_keys: tf.Tensor,
                         inputs: tf.Tensor,
                         num_samples: int,
                         de_config: de_config_pb2.DynamicEmbeddingConfig,
                         var_name: typing.Text,
                         service_address: typing.Text = "",
                         timeout_ms: int = -1):
  """Computes the sampled softmax loss from given input activations.

  Args:
    positive_keys: A string `Tensor` of shape `[batch_size, None]` representing
      input positive keys.
    inputs: A float `Tensor` of shape `[batch_size, dim]` representing the
      forward activations of the input network.
    num_samples: An int denoting the total number of returned positive and
      negative samples.
    de_config: A DynamicEmbeddingConfig for configuring the dynamic embedding.
    var_name: A unique name for the operation.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    A float `Tensor` representing the sampled softmax loss.
  """
  logits, labels, _, mask, _ = compute_sampled_logits(positive_keys, inputs,
                                                      num_samples, de_config,
                                                      var_name, service_address,
                                                      timeout_ms)
  tiled_norm = tf.tile(
      tf.maximum(tf.reduce_sum(labels, -1, keepdims=True), 1),
      [1, labels.get_shape()[-1]])
  labels /= tiled_norm
  return tf.reduce_sum(
      tf.nn.softmax_cross_entropy_with_logits(
          labels=labels, logits=logits)) / tf.reduce_sum(mask)

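# Example (a minimal sketch with placeholder names `config`, `model` and
# `features`; not part of the original module): the returned scalar can be
# minimized like any other TF loss, and the gradient registered at the bottom
# of this file is what updates the embeddings of the sampled keys.
#
#   positive_keys = tf.constant([["apple"], ["orange"]])
#   activations = model(features)  # shape [2, dim]
#   loss = sampled_softmax_loss(positive_keys, activations, num_samples=16,
#                               de_config=config, var_name="softmax_emb")

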
def sampled_sigmoid_loss(positive_keys: tf.Tensor,
                         inputs: tf.Tensor,
                         num_samples: int,
                         de_config: de_config_pb2.DynamicEmbeddingConfig,
                         var_name: typing.Text,
                         service_address: typing.Text = "",
                         timeout_ms: int = -1):
  """Computes the sampled sigmoid loss from given input activations.

  Args:
    positive_keys: A string `Tensor` of shape `[batch_size, None]` representing
      input positive keys.
    inputs: A float `Tensor` of shape `[batch_size, dim]` representing the
      forward activations of the input network.
    num_samples: An int denoting the total number of returned positive and
      negative samples.
    de_config: A DynamicEmbeddingConfig for configuring the dynamic embedding.
    var_name: A unique name for the operation.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    A float `Tensor` representing the sampled sigmoid loss.
  """
  logits, labels, _, mask, _ = compute_sampled_logits(positive_keys, inputs,
                                                      num_samples, de_config,
                                                      var_name, service_address,
                                                      timeout_ms)
  tiled_norm = tf.tile(
      tf.maximum(tf.reduce_sum(labels, -1, keepdims=True), 1),
      [1, labels.get_shape()[-1]])
  labels /= tiled_norm
  reduced_sum = tf.reduce_sum(
      tf.nn.sigmoid_cross_entropy_with_logits(
          labels=labels, logits=logits)) / tf.reduce_sum(mask)
  return reduced_sum / num_samples

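# Note: sampled_sigmoid_loss() treats each of the `num_samples` returned
# candidates as an independent binary classification, which is why the summed
# cross-entropy above is additionally divided by `num_samples`. A call looks
# identical to the softmax variant (placeholder names):
#
#   loss = sampled_sigmoid_loss(positive_keys, activations, num_samples=16,
#                               de_config=config, var_name="sigmoid_emb")

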
def compute_sampled_logits(positive_keys,
                           inputs,
                           num_samples: int,
                           de_config: de_config_pb2.DynamicEmbeddingConfig,
                           var_name: typing.Text,
                           service_address: typing.Text = "",
                           timeout_ms: int = -1):
  """Computes sampled logits from given positive keys.

  Args:
    positive_keys: A string `Tensor` of shape `[batch_size, None]` representing
      input positive keys.
    inputs: A float `Tensor` of shape `[batch_size, dim]` representing the
      forward activations of the input network.
    num_samples: An int denoting the total number of returned positive and
      negative samples.
    de_config: A DynamicEmbeddingConfig for configuring the dynamic embedding.
    var_name: A unique name for the operation.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    logits: A float `Tensor` of shape `[batch_size, num_samples]` representing
      the logits for the sampled keys.
    labels: A float `Tensor` of shape `[batch_size, num_samples]` with values
      in {0, 1} indicating if a sample is positive or negative.
    keys: A string `Tensor` of shape `[batch_size, num_samples]` representing
      the keys for each sample.
    mask: A float `Tensor` of shape `[batch_size]` representing the 0/1 mask
      of each batch. For example, if all keys in positive_keys[i] are empty,
      mask[i] = 0; otherwise mask[i] = 1.
    weights: A float `Tensor` representing the embeddings of the sampled keys.

  Raises:
    ValueError: If var_name is not specified or num_samples is not positive.
    TypeError: If de_config is not an instance of DynamicEmbeddingConfig.
  """
  if not var_name:
    raise ValueError("Must specify a valid name, got %s" % var_name)
  if num_samples < 1:
    raise ValueError("Invalid num_samples: %d" % num_samples)

  context.add_to_collection(var_name, de_config)
  resource = de_ops.dynamic_embedding_manager_resource(
      de_config.SerializeToString(), var_name, service_address, timeout_ms)

  # Create a dummy variable so that the gradients can be passed in.
  grad_placeholder = tf.Variable(0.0)

  keys, labels, expected_counts, mask, weights = (
      gen_sampled_logits_ops.sampled_logits_lookup(positive_keys, inputs,
                                                   num_samples,
                                                   grad_placeholder, resource))

  # Compute sampled logits.
  # Shape of weights: [d1, d2, ..., dn-1, num_samples, embed_dim]
  # Shape of inputs:  [d1, d2, ..., dn-1, embed_dim]
  # Shape of output logits: [d1, d2, ..., dn-1, num_samples]

  # [d1, d2, ..., dn-1, embed_dim] -> [d1, d2, ..., dn-1, 1, embed_dim]
  tiled_inputs = tf.expand_dims(inputs, axis=-2)
  # [d1, d2, ..., dn-1, 1, embed_dim] ->
  #     [d1, d2, ..., dn-1, num_samples, embed_dim]
  multiples = [1] * (inputs.shape.ndims + 1)
  multiples[-2] = num_samples
  tiled_inputs = tf.tile(tiled_inputs, multiples)
  # [d1, d2, ..., dn-1, num_samples, embed_dim] ->
  #     [d1, d2, ..., dn-1, num_samples]
  logits = tf.reduce_sum(weights * tiled_inputs, -1)
  # Sampled logits.
  logits -= tf.math.log(expected_counts)

  return logits, labels, keys, mask, weights

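# Shape sketch for the logits computation above (illustrative only, for the
# common 2-D case): with `inputs` of shape [batch_size, embed_dim] and
# `weights` of shape [batch_size, num_samples, embed_dim], the tile-and-reduce
# is equivalent to a batched dot product, e.g.
#
#   logits = tf.einsum("bse,be->bs", weights, inputs)
#
# followed by subtracting log(expected_counts), the usual correction that turns
# raw inner-product scores into sampled logits.

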
@tf.RegisterGradient("SampledLogitsLookup")
def _sampled_logits_lookup_grad(op, keys_grad, labels_grad,
                                expected_counts_grad, mask_grad, weights_grad):
  """Computes the gradients for SampledLogitsLookup.

  We use the gradients w.r.t. the weights output of sampled_logits_lookup() to
  update the embeddings/weights of the sampled keys. Gradients for the inputs
  of sampled_logits_lookup() must still be returned, but none of them needs to
  be back-propagated, so the gradient w.r.t. the input activations is set to
  zeros.

  Args:
    op: The SampledLogitsLookup op.
    keys_grad: The tensor representing the gradient w.r.t. the keys output.
    labels_grad: The tensor representing the gradient w.r.t. the labels output.
    expected_counts_grad: The tensor representing the gradient w.r.t. the
      expected_counts output.
    mask_grad: The tensor representing the gradient w.r.t. the mask output.
    weights_grad: The tensor representing the gradient w.r.t. the weights
      output.

  Returns:
    The gradients w.r.t. the inputs of the op.
  """
  del keys_grad, labels_grad, expected_counts_grad, mask_grad  # Unused.

  pos_keys_grad, num_samples_grad, dummy_variable_grad, resource_grad = (
      gen_sampled_logits_ops.sampled_logits_lookup_grad(
          keys=op.outputs[0],
          weight_gradients=weights_grad,
          handle=op.inputs[4]))
  # Gradient for the input activations.
  inputs_grad = tf.zeros_like(op.inputs[1])
  return (pos_keys_grad, inputs_grad, num_samples_grad, dummy_variable_grad,
          resource_grad)
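

# With the gradient registered above, differentiating a sampled loss causes
# SampledLogitsLookup's gradient to run, which is what pushes the weight
# gradients for the sampled keys back toward the dynamic embedding service.
# A rough sketch (placeholder names; the exact training-loop wiring depends on
# the surrounding CARLS setup):
#
#   with tf.GradientTape() as tape:
#     loss = sampled_softmax_loss(positive_keys, model(features), 16,
#                                 de_config=config, var_name="softmax_emb")
#   grads = tape.gradient(loss, model.trainable_variables)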