Skip to content

Commit 2b7127c

Browse files
Neural-Link Teamtensorflow-copybara
authored andcommitted
Add a Contains() interface to KnowledgeBank for probing the keys without affecting its states.
PiperOrigin-RevId: 367657892
1 parent 89b56e1 commit 2b7127c

File tree

8 files changed

+78
-47
lines changed

8 files changed

+78
-47
lines changed

research/carls/candidate_sampling/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ cc_test(
112112
"//research/carls/knowledge_bank:initializer_cc_proto",
113113
"//research/carls/knowledge_bank:initializer_helper",
114114
"//research/carls/testing:test_helper",
115+
"@com_google_absl//absl/strings",
115116
"@com_google_googletest//:gtest_main",
116117
],
117118
)

research/carls/candidate_sampling/brute_force_topk_sampler_test.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
#include "gmock/gmock.h"
1717
#include "gtest/gtest.h"
18+
#include "absl/strings/string_view.h"
1819
#include "research/carls/candidate_sampling/candidate_sampler.h"
1920
#include "research/carls/embedding.pb.h" // proto to pb
2021
#include "research/carls/knowledge_bank/initializer.pb.h" // proto to pb
@@ -80,6 +81,8 @@ class FakeEmbedding : public KnowledgeBank {
8081

8182
std::vector<absl::string_view> Keys() const { return keys_; }
8283

84+
bool Contains(absl::string_view key) const { return true; }
85+
8386
private:
8487
InProtoKnowledgeBankConfig::EmbeddingData data_table_;
8588
std::vector<absl::string_view> keys_;
@@ -193,9 +196,8 @@ TEST_F(BruteForceTopkSamplerTest, DotProtudctSimilarity) {
193196
context.mutable_activation()->add_value(2);
194197

195198
std::vector<SampledResult> results;
196-
ASSERT_TRUE(
197-
sampler->Sample(*knowledge_bank, context, /*num_samples=*/1, &results)
198-
.ok());
199+
ASSERT_OK(
200+
sampler->Sample(*knowledge_bank, context, /*num_samples=*/1, &results));
199201
ASSERT_EQ(1, results.size());
200202
EXPECT_THAT(results[0], EqualsProto<SampledResult>(R"pb(
201203
topk_sampling_result {

research/carls/candidate_sampling/log_uniform_sampler_test.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ class FakeKnowledgeBank : public KnowledgeBank {
6969
return absl::OkStatus();
7070
}
7171

72+
// Never called.
73+
bool Contains(absl::string_view key) const { return true; }
74+
7275
private:
7376
std::vector<std::string> str_keys_;
7477
std::vector<absl::string_view> keys_;
@@ -101,10 +104,8 @@ TEST_F(LogUniformSamplerTest, SampingWithReplacement) {
101104
std::vector<SampledResult> results;
102105

103106
// Same number of positives and num_samples = num_total_keys.
104-
ASSERT_TRUE(sampler
105-
->Sample(FakeKnowledgeBank(/*num_keys=*/2), context,
106-
/*num_samples=*/2, &results)
107-
.ok());
107+
ASSERT_OK(sampler->Sample(FakeKnowledgeBank(/*num_keys=*/2), context,
108+
/*num_samples=*/2, &results));
108109
ASSERT_EQ(2, results.size());
109110
EXPECT_TRUE(results[0].negative_sampling_result().is_positive());
110111
EXPECT_TRUE(results[1].negative_sampling_result().is_positive());

research/carls/knowledge_bank/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ cc_library(
131131
"//research/carls/base:file_helper",
132132
"//research/carls/base:proto_helper",
133133
"@com_google_absl//absl/status",
134+
"@com_google_absl//absl/strings",
134135
"@com_google_absl//absl/synchronization",
135136
],
136137
alwayslink = 1,

research/carls/knowledge_bank/in_proto_knowledge_bank.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
#include "absl/status/status.h"
17+
#include "absl/strings/string_view.h"
1718
#include "absl/synchronization/mutex.h"
1819
#include "research/carls/base/file_helper.h"
1920
#include "research/carls/base/proto_helper.h"
@@ -60,6 +61,13 @@ class InProtoKnowledgeBank : public KnowledgeBank {
6061
// Implementation of the Keys interface.
6162
std::vector<absl::string_view> Keys() const override;
6263

64+
// Implementation of the Contains interface.
65+
bool Contains(absl::string_view key) const {
66+
absl::ReaderMutexLock l(&mu_);
67+
return in_proto_config_.embedding_data().embedding_table().contains(
68+
std::string(key));
69+
}
70+
6371
mutable absl::Mutex mu_;
6472
InProtoKnowledgeBankConfig in_proto_config_ ABSL_GUARDED_BY(mu_);
6573

research/carls/knowledge_bank/in_proto_knowledge_bank_test.cc

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,22 @@ TEST_F(InProtoKnowledgeBankTest, LookupAndUpdate) {
4545
EmbeddingVectorProto value;
4646
value.add_value(1.0f);
4747
value.add_value(2.0f);
48-
EXPECT_TRUE(store->Update("key1", value).ok());
48+
EXPECT_OK(store->Update("key1", value));
4949

5050
EmbeddingVectorProto result;
51-
EXPECT_TRUE(store->Lookup("key1", &result).ok());
51+
EXPECT_OK(store->Lookup("key1", &result));
5252
EXPECT_THAT(result, EqualsProto<EmbeddingVectorProto>(R"pb(
5353
value: 1 value: 2
5454
)pb"));
5555

56-
EXPECT_FALSE(store->Lookup("key2", &result).ok());
56+
EXPECT_NOT_OK(store->Lookup("key2", &result));
5757

5858
// Checks size and keys of embedding.
5959
EXPECT_EQ(1, store->Size());
6060
ASSERT_EQ(1, store->Keys().size());
6161
EXPECT_EQ("key1", store->Keys()[0]);
62+
EXPECT_TRUE(store->Contains("key1"));
63+
EXPECT_FALSE(store->Contains("key2"));
6264
}
6365

6466
TEST_F(InProtoKnowledgeBankTest, BatchLookupAndUpdate) {
@@ -82,7 +84,7 @@ TEST_F(InProtoKnowledgeBankTest, BatchLookupAndUpdate) {
8284
// Checks the BatchUpdate and BatchLookup.
8385
auto statuses = store->BatchUpdate(keys, values);
8486
for (const auto& status : statuses) {
85-
ASSERT_TRUE(status.ok());
87+
ASSERT_OK(status);
8688
}
8789

8890
// Checks the BatchLookup results.
@@ -103,12 +105,14 @@ TEST_F(InProtoKnowledgeBankTest, BatchLookupAndUpdate) {
103105
for (int i = 0; i < batch_size; ++i) {
104106
EXPECT_EQ(absl::StrCat("key", i), store->Keys()[i]);
105107
}
108+
EXPECT_TRUE(store->Contains("key99"));
109+
EXPECT_FALSE(store->Contains("key199"));
106110
}
107111

108112
TEST_F(InProtoKnowledgeBankTest, LookupWithUpdate) {
109113
auto store = CreateDefaultStore(2);
110114
EmbeddingVectorProto result;
111-
ASSERT_TRUE(store->LookupWithUpdate("key1", &result).ok());
115+
ASSERT_OK(store->LookupWithUpdate("key1", &result));
112116
EXPECT_THAT(result, EqualsProto<EmbeddingVectorProto>(R"pb(
113117
tag: "key1"
114118
value: 0
@@ -117,7 +121,7 @@ TEST_F(InProtoKnowledgeBankTest, LookupWithUpdate) {
117121
)pb"));
118122

119123
// Checks that weight is incremented by 1.
120-
ASSERT_TRUE(store->LookupWithUpdate("key1", &result).ok());
124+
ASSERT_OK(store->LookupWithUpdate("key1", &result));
121125
EXPECT_THAT(result, EqualsProto<EmbeddingVectorProto>(R"pb(
122126
tag: "key1"
123127
value: 0
@@ -129,6 +133,7 @@ TEST_F(InProtoKnowledgeBankTest, LookupWithUpdate) {
129133
EXPECT_EQ(1, store->Size());
130134
ASSERT_EQ(1, store->Keys().size());
131135
EXPECT_EQ("key1", store->Keys()[0]);
136+
EXPECT_TRUE(store->Contains("key1"));
132137
}
133138

134139
TEST_F(InProtoKnowledgeBankTest, BatchLookupWithUpdate) {
@@ -144,8 +149,7 @@ TEST_F(InProtoKnowledgeBankTest, BatchLookupWithUpdate) {
144149
std::vector<absl::string_view> keys(str_keys.begin(), str_keys.end());
145150

146151
// Checks the BatchLookupWithUpdate returns the zero initialized values.
147-
std::vector<absl::variant<EmbeddingVectorProto, std::string>>
148-
results;
152+
std::vector<absl::variant<EmbeddingVectorProto, std::string>> results;
149153
store->BatchLookupWithUpdate(keys, &results);
150154
int i = 0;
151155
for (const auto& result : results) {
@@ -170,12 +174,12 @@ TEST_F(InProtoKnowledgeBankTest, Export) {
170174

171175
// Even the time changes, the length should always be the same.
172176
std::string exported_path;
173-
ASSERT_TRUE(store->Export(TempDir(), "", &exported_path).ok());
177+
ASSERT_OK(store->Export(TempDir(), "", &exported_path));
174178
EXPECT_EQ(JoinPath(TempDir(), "embedding_store_meta_data.pbtxt"),
175179
exported_path);
176180

177181
KnowledgeBankCheckpointMetaData meta_data;
178-
ASSERT_TRUE(ReadTextProto(exported_path, &meta_data).ok());
182+
ASSERT_OK(ReadTextProto(exported_path, &meta_data));
179183
EXPECT_EQ(JoinPath(TempDir(), "in_proto_embedding_data.pbbin"),
180184
meta_data.checkpoint_saved_path());
181185
}
@@ -185,20 +189,20 @@ TEST_F(InProtoKnowledgeBankTest, Import) {
185189

186190
// Some updates.
187191
EmbeddingVectorProto result;
188-
EXPECT_TRUE(store->LookupWithUpdate("key1", &result).ok());
189-
EXPECT_TRUE(store->LookupWithUpdate("key2", &result).ok());
190-
EXPECT_TRUE(store->LookupWithUpdate("key3", &result).ok());
191-
EXPECT_TRUE(store->LookupWithUpdate("key2", &result).ok());
192-
EXPECT_TRUE(store->LookupWithUpdate("key2", &result).ok());
192+
EXPECT_OK(store->LookupWithUpdate("key1", &result));
193+
EXPECT_OK(store->LookupWithUpdate("key2", &result));
194+
EXPECT_OK(store->LookupWithUpdate("key3", &result));
195+
EXPECT_OK(store->LookupWithUpdate("key2", &result));
196+
EXPECT_OK(store->LookupWithUpdate("key2", &result));
193197

194198
// Now saves a checkpoint.
195199
std::string exported_path;
196-
ASSERT_TRUE(store->Export(TempDir(), "", &exported_path).ok());
200+
ASSERT_OK(store->Export(TempDir(), "", &exported_path));
197201

198202
// Some updates.
199-
EXPECT_TRUE(store->LookupWithUpdate("key1", &result).ok());
200-
EXPECT_TRUE(store->LookupWithUpdate("key4", &result).ok());
201-
EXPECT_TRUE(store->LookupWithUpdate("key5", &result).ok());
203+
EXPECT_OK(store->LookupWithUpdate("key1", &result));
204+
EXPECT_OK(store->LookupWithUpdate("key4", &result));
205+
EXPECT_OK(store->LookupWithUpdate("key5", &result));
202206

203207
// Checks size and keys of embedding.
204208
EXPECT_EQ(5, store->Size());
@@ -210,7 +214,7 @@ TEST_F(InProtoKnowledgeBankTest, Import) {
210214
EXPECT_EQ("key5", store->Keys()[4]);
211215

212216
// Import previous state.
213-
ASSERT_TRUE(store->Import(exported_path).ok());
217+
ASSERT_OK(store->Import(exported_path));
214218

215219
// Checks size and keys of embedding again.
216220
EXPECT_EQ(3, store->Size());

research/carls/knowledge_bank/knowledge_bank.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,11 @@ class KnowledgeBank {
9494
// Returns the list of keys in the knowledge bank.
9595
virtual std::vector<absl::string_view> Keys() const = 0;
9696

97+
// Check if a given key is already in the knowledge bank or not.
98+
// This is used for probing the bank without affecting weights of the
99+
// embedding.
100+
virtual bool Contains(absl::string_view key) const = 0;
101+
97102
protected:
98103
KnowledgeBank(const KnowledgeBankConfig& config,
99104
const int embedding_dimension);

research/carls/knowledge_bank/knowledge_bank_test.cc

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ class FakeEmbedding : public KnowledgeBank {
8484

8585
std::vector<absl::string_view> Keys() const { return keys_; }
8686

87+
bool Contains(absl::string_view key) const {
88+
return data_table_.embedding_table().contains(std::string(key));
89+
}
90+
8791
private:
8892
InProtoKnowledgeBankConfig::EmbeddingData data_table_;
8993
std::vector<absl::string_view> keys_;
@@ -114,20 +118,25 @@ TEST_F(KnowledgeBankTest, Basic) {
114118
auto store = CreateDefaultStore(10);
115119

116120
EXPECT_EQ(10, store->embedding_dimension());
121+
122+
EmbeddingVectorProto embed;
123+
ASSERT_OK(store->LookupWithUpdate("key1", &embed));
124+
EXPECT_TRUE(store->Contains("key1"));
125+
EXPECT_FALSE(store->Contains("key2"));
117126
}
118127

119128
TEST_F(KnowledgeBankTest, LookupAndUpdate) {
120129
auto store = CreateDefaultStore(2);
121130
EmbeddingInitializer initializer;
122131
initializer.mutable_zero_initializer();
123132
EmbeddingVectorProto value = InitializeEmbedding(2, initializer);
124-
EXPECT_TRUE(store->Update("key1", value).ok());
133+
EXPECT_OK(store->Update("key1", value));
125134

126135
EmbeddingVectorProto result;
127-
EXPECT_TRUE(store->Lookup("key1", &result).ok());
128-
EXPECT_THAT(result, EqualsProto<EmbeddingVectorProto>(R"(
136+
EXPECT_OK(store->Lookup("key1", &result));
137+
EXPECT_THAT(result, EqualsProto<EmbeddingVectorProto>(R"pb(
129138
value: 0 value: 0
130-
)"));
139+
)pb"));
131140
EXPECT_EQ(1, store->Size());
132141
ASSERT_EQ(1, store->Keys().size());
133142
EXPECT_EQ("key1", store->Keys()[0]);
@@ -150,9 +159,9 @@ TEST_F(KnowledgeBankTest, BatchLookupAndUpdate) {
150159
ASSERT_TRUE(
151160
absl::holds_alternative<EmbeddingVectorProto>(value_or_errors[i]));
152161
EXPECT_THAT(absl::get<EmbeddingVectorProto>(value_or_errors[i]),
153-
EqualsProto<EmbeddingVectorProto>(R"(
162+
EqualsProto<EmbeddingVectorProto>(R"pb(
154163
value: 0 value: 0
155-
)"));
164+
)pb"));
156165
}
157166
ASSERT_TRUE(absl::holds_alternative<std::string>(value_or_errors[2]));
158167
EXPECT_EQ("Data not found", absl::get<std::string>(value_or_errors[2]));
@@ -172,9 +181,9 @@ TEST_F(KnowledgeBankTest, BatchLookupWithUpdate) {
172181
ASSERT_TRUE(
173182
absl::holds_alternative<EmbeddingVectorProto>(value_or_errors[i]));
174183
EXPECT_THAT(absl::get<EmbeddingVectorProto>(value_or_errors[i]),
175-
EqualsProto<EmbeddingVectorProto>(R"(
184+
EqualsProto<EmbeddingVectorProto>(R"pb(
176185
value: 0 value: 0
177-
)"));
186+
)pb"));
178187
}
179188
EXPECT_EQ(3, store->Size());
180189
ASSERT_EQ(3, store->Keys().size());
@@ -188,36 +197,36 @@ TEST_F(KnowledgeBankTest, Export) {
188197

189198
// Export to a new dir.
190199
std::string exported_path;
191-
ASSERT_TRUE(store->Export(TempDir(), "", &exported_path).ok());
200+
ASSERT_OK(store->Export(TempDir(), "", &exported_path));
192201
EXPECT_EQ(JoinPath(TempDir(), "embedding_store_meta_data.pbtxt"),
193202
exported_path);
194203

195204
// Export to the same dir again, it overwrites existing checkpoint.
196-
ASSERT_TRUE(store->Export(TempDir(), "", &exported_path).ok());
205+
ASSERT_OK(store->Export(TempDir(), "", &exported_path));
197206
}
198207

199208
TEST_F(KnowledgeBankTest, Import) {
200209
auto store = CreateDefaultStore(2);
201210

202211
// Some updates.
203212
EmbeddingVectorProto result;
204-
EXPECT_TRUE(store->LookupWithUpdate("key1", &result).ok());
205-
EXPECT_TRUE(store->LookupWithUpdate("key2", &result).ok());
206-
EXPECT_TRUE(store->LookupWithUpdate("key3", &result).ok());
207-
EXPECT_TRUE(store->LookupWithUpdate("key2", &result).ok());
208-
EXPECT_TRUE(store->LookupWithUpdate("key2", &result).ok());
213+
EXPECT_OK(store->LookupWithUpdate("key1", &result));
214+
EXPECT_OK(store->LookupWithUpdate("key2", &result));
215+
EXPECT_OK(store->LookupWithUpdate("key3", &result));
216+
EXPECT_OK(store->LookupWithUpdate("key2", &result));
217+
EXPECT_OK(store->LookupWithUpdate("key2", &result));
209218

210219
// Now saves a checkpoint.
211220
std::string exported_path;
212-
ASSERT_TRUE(store->Export(TempDir(), "", &exported_path).ok());
221+
ASSERT_OK(store->Export(TempDir(), "", &exported_path));
213222

214223
// Some updates.
215-
EXPECT_TRUE(store->LookupWithUpdate("key1", &result).ok());
216-
EXPECT_TRUE(store->LookupWithUpdate("key4", &result).ok());
217-
EXPECT_TRUE(store->LookupWithUpdate("key5", &result).ok());
224+
EXPECT_OK(store->LookupWithUpdate("key1", &result));
225+
EXPECT_OK(store->LookupWithUpdate("key4", &result));
226+
EXPECT_OK(store->LookupWithUpdate("key5", &result));
218227

219228
// Import previous state.
220-
ASSERT_TRUE(store->Import(exported_path).ok());
229+
ASSERT_OK(store->Import(exported_path));
221230
}
222231

223232
} // namespace carls

0 commit comments

Comments
 (0)