Skip to content

Commit 0544f9e

Browse files
Neural-Link Team, tensorflow-copybara
authored and committed
Add a candidate_sampler_config_builder to help build protos.
PiperOrigin-RevId: 367713072
1 parent 2b7127c commit 0544f9e

File tree

3 files changed

+184
-0
lines changed

3 files changed

+184
-0
lines changed

research/carls/candidate_sampling/BUILD

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,27 @@ carls_py_proto_library(
3939
],
4040
)
4141

42+
# Python helper library for building CandidateSamplerConfig protos; depends on
# the generated Python proto target declared in this package.
py_library(
    name = "candidate_sampler_config_builder_py",
    srcs = ["candidate_sampler_config_builder.py"],
    srcs_version = "PY3",
    deps = [
        ":candidate_sampler_config_py_pb2",
    ],
)
50+
51+
# Unit tests for candidate_sampler_config_builder.py.
py_test(
    name = "candidate_sampler_config_builder_test",
    size = "small",
    srcs = ["candidate_sampler_config_builder_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":candidate_sampler_config_builder_py",
        ":candidate_sampler_config_py_pb2",
    ],
)
62+
4263
cc_library(
4364
name = "candidate_sampler",
4465
srcs = ["candidate_sampler.cc"],
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""A library for building a CandidateSamplerConfig."""
15+
16+
import typing
17+
18+
from research.carls.candidate_sampling import candidate_sampler_config_pb2 as cs_config_pb2
19+
20+
21+
def log_uniform_sampler(unique: bool) -> cs_config_pb2.LogUniformSamplerConfig:
  """Returns a LogUniformSamplerConfig with its `unique` field set.

  Args:
    unique: value to store in the config's `unique` field.

  Returns:
    A LogUniformSamplerConfig constructed with the given `unique` value.
  """
  config = cs_config_pb2.LogUniformSamplerConfig(unique=unique)
  return config
23+
24+
25+
def brute_force_topk_sampler(
    similarity_type) -> cs_config_pb2.BruteForceTopkSamplerConfig:
  """Returns a BruteForceTopkSamplerConfig based on given similarity type.

  Args:
    similarity_type: A string or an int indicating the type of similarity
      defined in carls.candidate_sampling.SimilarityType. A string is resolved
      by enum name via `SimilarityType.Value`. Only COSINE and DOT_PRODUCT are
      accepted.

  Returns:
    An instance of BruteForceTopkSamplerConfig if input is valid.

  Raises:
    ValueError: if input is invalid, i.e. it is an unknown enum name, an int
      other than COSINE or DOT_PRODUCT, or a value that is neither a string
      nor an int (bools are rejected as well).
  """
  if isinstance(similarity_type, typing.Text):
    # Enum lookup by name; an unknown name surfaces as ValueError from Value().
    similarity_type = cs_config_pb2.SimilarityType.Value(similarity_type)

  # bool is a subclass of int, so without the explicit check True/False would
  # be silently interpreted as enum numbers.
  if isinstance(similarity_type, bool) or not isinstance(similarity_type, int):
    raise ValueError('Invalid input: %r' % similarity_type)

  supported = (cs_config_pb2.COSINE, cs_config_pb2.DOT_PRODUCT)
  if similarity_type not in supported:
    raise ValueError('Invalid similarity type: %r' % similarity_type)

  return cs_config_pb2.BruteForceTopkSamplerConfig(
      similarity_type=similarity_type)
49+
50+
51+
def build_candidate_sampler_config(
    sampler) -> cs_config_pb2.CandidateSamplerConfig:
  """Wraps a sampler config message in a CandidateSamplerConfig.

  Args:
    sampler: an instance of LogUniformSamplerConfig or
      BruteForceTopkSamplerConfig.

  Returns:
    A CandidateSamplerConfig with `sampler` packed into its `extension` field.

  Raises:
    ValueError: if `sampler` is not an instance of a supported config type.
  """
  supported_types = (cs_config_pb2.LogUniformSamplerConfig,
                     cs_config_pb2.BruteForceTopkSamplerConfig)
  if isinstance(sampler, supported_types):
    result = cs_config_pb2.CandidateSamplerConfig()
    result.extension.Pack(sampler)
    return result

  raise ValueError(
      'sampler must be one of LogUniformSamplerConfig or BruteForceTopkSamplerConfig'
  )
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for candidate_sampler_config_builder."""
15+
16+
from research.carls.candidate_sampling import candidate_sampler_config_builder as cs_config_builder
17+
from research.carls.candidate_sampling import candidate_sampler_config_pb2 as cs_config_pb2
18+
19+
import tensorflow as tf
20+
21+
22+
class CandidateSamplerConfigBuilderTest(tf.test.TestCase):
  """Tests for the helpers in candidate_sampler_config_builder."""

  def test_log_uniform_sampler(self):
    # Both values of `unique` must show up in the resulting config.
    unique_config = cs_config_builder.log_uniform_sampler(True)
    self.assertProtoEquals("""
      unique: true
    """, unique_config)

    non_unique_config = cs_config_builder.log_uniform_sampler(False)
    self.assertProtoEquals("""
      unique: false
    """, non_unique_config)

  def test_brute_force_topk_sampler_success(self):
    # Each supported similarity type is accepted by name and by enum value.
    cosine_from_name = cs_config_builder.brute_force_topk_sampler('COSINE')
    self.assertProtoEquals("""
      similarity_type: COSINE
    """, cosine_from_name)

    cosine_from_enum = cs_config_builder.brute_force_topk_sampler(
        cs_config_pb2.COSINE)
    self.assertProtoEquals(
        """
      similarity_type: COSINE
    """, cosine_from_enum)

    dot_from_name = cs_config_builder.brute_force_topk_sampler('DOT_PRODUCT')
    self.assertProtoEquals(
        """
      similarity_type: DOT_PRODUCT
    """, dot_from_name)

    dot_from_enum = cs_config_builder.brute_force_topk_sampler(
        cs_config_pb2.DOT_PRODUCT)
    self.assertProtoEquals(
        """
      similarity_type: DOT_PRODUCT
    """, dot_from_enum)

  def test_brute_force_topk_sampler_failed(self):
    # An unsupported enum value, an unknown name, a non-str/int object and an
    # out-of-range int must each raise ValueError.
    invalid_inputs = (
        cs_config_pb2.UNKNOWN,
        'Unknown type string',
        cs_config_pb2.SampleContext(),
        999,
    )
    for invalid in invalid_inputs:
      with self.assertRaises(ValueError):
        cs_config_builder.brute_force_topk_sampler(invalid)

  def test_build_candidate_sampler_config_success(self):
    # The sampler config is expected inside the `extension` field.
    topk_sampler = cs_config_builder.brute_force_topk_sampler('COSINE')
    self.assertProtoEquals(
        """
      extension {
        [type.googleapis.com/carls.candidate_sampling.BruteForceTopkSamplerConfig] {
          similarity_type: COSINE
        }
      }
    """, cs_config_builder.build_candidate_sampler_config(topk_sampler))

    log_uniform = cs_config_builder.log_uniform_sampler(True)
    self.assertProtoEquals(
        """
      extension {
        [type.googleapis.com/carls.candidate_sampling.LogUniformSamplerConfig] {
          unique: true
        }
      }
    """, cs_config_builder.build_candidate_sampler_config(log_uniform))

  def test_build_candidate_sampler_config_failed(self):
    # Anything that is not one of the two sampler config messages is rejected.
    for invalid in (100, 'invalid'):
      with self.assertRaises(ValueError):
        cs_config_builder.build_candidate_sampler_config(invalid)
87+
88+
89+
# Run the tests through the TensorFlow test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()

0 commit comments

Comments
 (0)