Skip to content

Commit 4b226cb

Browse files
Neural-Link Team authored and
tensorflow-copybara committed
Add uniform sampler and combine it with log_uniform sampler into negative_sampler.
PiperOrigin-RevId: 368921933
1 parent f17eb3a commit 4b226cb

10 files changed

+387
-233
lines changed

research/carls/BUILD

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -153,7 +153,7 @@ cc_library(
153153
":knowledge_bank_service_cc_grpc_proto",
154154
"//research/carls/candidate_sampling:brute_force_topk_sampler",
155155
"//research/carls/candidate_sampling:candidate_sampler",
156-
"//research/carls/candidate_sampling:log_uniform_sampler",
156+
"//research/carls/candidate_sampling:negative_sampler",
157157
"//research/carls/gradient_descent:gradient_descent_optimizer",
158158
"//research/carls/knowledge_bank",
159159
"//research/carls/knowledge_bank:in_proto_knowledge_bank",

research/carls/candidate_sampling/BUILD

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -84,8 +84,8 @@ cc_test(
8484
)
8585

8686
cc_library(
87-
name = "log_uniform_sampler",
88-
srcs = ["log_uniform_sampler.cc"],
87+
name = "negative_sampler",
88+
srcs = ["negative_sampler.cc"],
8989
deps = [
9090
":candidate_sampler",
9191
"//research/carls/base:proto_helper",
@@ -98,8 +98,8 @@ cc_library(
9898
)
9999

100100
cc_test(
101-
name = "log_uniform_sampler_test",
102-
srcs = ["log_uniform_sampler_test.cc"],
101+
name = "negative_sampler_test",
102+
srcs = ["negative_sampler_test.cc"],
103103
deps = [
104104
":candidate_sampler",
105105
":log_uniform_sampler",

research/carls/candidate_sampling/candidate_sampler_config.proto

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,18 @@ message CandidateSamplerConfig {
1515
// A sampler used in the sampled softmax/logistic layer.
1616
// See https://www.tensorflow.org/extras/candidate_sampling.pdf for more
1717
// details.
18-
message LogUniformSamplerConfig {
18+
message NegativeSamplerConfig {
1919
// If the sampled key should be unique or not.
2020
bool unique = 1;
21+
22+
enum Sampler {
23+
UNKNOWN = 0;
24+
// Sample uniformly among both positive and negative data.
25+
UNIFORM = 1;
26+
// Sample log uniformly among both positive and negative data.
27+
LOG_UNIFORM = 2;
28+
}
29+
Sampler sampler = 2;
2130
}
2231

2332
// Types of similarity between two embeddings.

research/carls/candidate_sampling/candidate_sampler_config_builder.py

Lines changed: 26 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -18,8 +18,29 @@
1818
from research.carls.candidate_sampling import candidate_sampler_config_pb2 as cs_config_pb2
1919

2020

21-
def log_uniform_sampler(unique: bool) -> cs_config_pb2.LogUniformSamplerConfig:
22-
return cs_config_pb2.LogUniformSamplerConfig(unique=unique)
21+
def negative_sampler(unique: bool,
22+
algorithm) -> cs_config_pb2.NegativeSamplerConfig:
23+
"""Builds a NegativeSamplerConfig from given input.
24+
25+
Args:
26+
unique: a bool indicating if the samples should be unique.
27+
algorithm: the sampler algorithm defined by NegativeSamplerConfig.Sampler.
28+
29+
Returns:
30+
A NegativeSamplerConfig proto.
31+
"""
32+
if isinstance(algorithm, typing.Text):
33+
algorithm = cs_config_pb2.NegativeSamplerConfig.Sampler.Value(algorithm)
34+
if isinstance(algorithm, int):
35+
if algorithm not in [
36+
cs_config_pb2.NegativeSamplerConfig.UNIFORM,
37+
cs_config_pb2.NegativeSamplerConfig.LOG_UNIFORM
38+
]:
39+
raise ValueError('Invalid sampler type.')
40+
else:
41+
raise ValueError('Invalid input: %r' % algorithm)
42+
43+
return cs_config_pb2.NegativeSamplerConfig(unique=unique, sampler=algorithm)
2344

2445

2546
def brute_force_topk_sampler(
@@ -53,7 +74,7 @@ def build_candidate_sampler_config(
5374
"""Builds a CandidateSamplerConfig from given sampler.
5475
5576
Args:
56-
sampler: an instance of LogUniformSamplerConfig or
77+
sampler: an instance of NegativeSamplerConfig or
5778
BruteForceTopkSamplerConfig.
5879
5980
Returns:
@@ -62,10 +83,10 @@ def build_candidate_sampler_config(
6283
Raises:
6384
ValueError if `sampler` is not valid.
6485
"""
65-
if not (isinstance(sampler, cs_config_pb2.LogUniformSamplerConfig) or
86+
if not (isinstance(sampler, cs_config_pb2.NegativeSamplerConfig) or
6687
isinstance(sampler, cs_config_pb2.BruteForceTopkSamplerConfig)):
6788
raise ValueError(
68-
'sampler must be one of LogUniformSamplerConfig or BruteForceTopkSamplerConfig'
89+
'sampler must be one of NegativeSamplerConfig or BruteForceTopkSamplerConfig'
6990
)
7091

7192
sampler_config = cs_config_pb2.CandidateSamplerConfig()

research/carls/candidate_sampling/candidate_sampler_config_builder_test.py

Lines changed: 40 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -21,13 +21,45 @@
2121

2222
class CandidateSamplerConfigBuilderTest(tf.test.TestCase):
2323

24-
def test_log_uniform_sampler(self):
25-
self.assertProtoEquals("""
24+
def test_negative_sampler(self):
25+
self.assertProtoEquals(
26+
"""
2627
unique: true
27-
""", cs_config_builder.log_uniform_sampler(True))
28-
self.assertProtoEquals("""
28+
sampler: UNIFORM
29+
""",
30+
cs_config_builder.negative_sampler(
31+
True, cs_config_pb2.NegativeSamplerConfig.UNIFORM))
32+
self.assertProtoEquals(
33+
"""
34+
unique: false
35+
sampler: UNIFORM
36+
""",
37+
cs_config_builder.negative_sampler(
38+
False, cs_config_pb2.NegativeSamplerConfig.UNIFORM))
39+
self.assertProtoEquals(
40+
"""
41+
unique: true
42+
sampler: LOG_UNIFORM
43+
""",
44+
cs_config_builder.negative_sampler(
45+
True, cs_config_pb2.NegativeSamplerConfig.LOG_UNIFORM))
46+
self.assertProtoEquals(
47+
"""
2948
unique: false
30-
""", cs_config_builder.log_uniform_sampler(False))
49+
sampler: LOG_UNIFORM
50+
""",
51+
cs_config_builder.negative_sampler(
52+
False, cs_config_pb2.NegativeSamplerConfig.LOG_UNIFORM))
53+
self.assertProtoEquals(
54+
"""
55+
unique: false
56+
sampler: UNIFORM
57+
""", cs_config_builder.negative_sampler(False, 'UNIFORM'))
58+
self.assertProtoEquals(
59+
"""
60+
unique: true
61+
sampler: LOG_UNIFORM
62+
""", cs_config_builder.negative_sampler(True, 'LOG_UNIFORM'))
3163

3264
def test_brute_force_topk_sampler_success(self):
3365
self.assertProtoEquals("""
@@ -71,13 +103,14 @@ def test_build_candidate_sampler_config_success(self):
71103
self.assertProtoEquals(
72104
"""
73105
extension {
74-
[type.googleapis.com/carls.candidate_sampling.LogUniformSamplerConfig] {
106+
[type.googleapis.com/carls.candidate_sampling.NegativeSamplerConfig] {
75107
unique: true
108+
sampler: UNIFORM
76109
}
77110
}
78111
""",
79112
cs_config_builder.build_candidate_sampler_config(
80-
cs_config_builder.log_uniform_sampler(True)))
113+
cs_config_builder.negative_sampler(True, 'UNIFORM')))
81114

82115
def test_build_candidate_sampler_config_failed(self):
83116
with self.assertRaises(ValueError):

research/carls/candidate_sampling/log_uniform_sampler_test.cc

Lines changed: 0 additions & 189 deletions
This file was deleted.

0 commit comments

Comments
 (0)