Skip to content

Commit 4eeb3ba

Browse files
committed
Fix tensorflow compilation with flatbuffer upgrade and v2.22.0 patch
Signed-off-by: Dom Del Nano <ddelnano@gmail.com>
1 parent 0cdef83 commit 4eeb3ba

File tree

8 files changed

+530
-57
lines changed

8 files changed

+530
-57
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
diff --git a/BUILD.bazel b/BUILD.bazel
2+
index b4f015a0..70848962 100644
3+
--- a/BUILD.bazel
4+
+++ b/BUILD.bazel
5+
@@ -1,5 +1,3 @@
6+
-load("@aspect_rules_js//npm:defs.bzl", "npm_link_package")
7+
-load("@npm//:defs.bzl", "npm_link_all_packages")
8+
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library")
9+
10+
licenses(["notice"])
11+
@@ -8,12 +6,7 @@ package(
12+
default_visibility = ["//visibility:public"],
13+
)
14+
15+
-npm_link_all_packages(name = "node_modules")
16+
17+
-npm_link_package(
18+
- name = "node_modules/flatbuffers",
19+
- src = "//ts:flatbuffers",
20+
-)
21+
22+
exports_files([
23+
"LICENSE",
24+
@@ -40,11 +33,9 @@ filegroup(
25+
"BUILD.bazel",
26+
"WORKSPACE",
27+
"build_defs.bzl",
28+
- "typescript.bzl",
29+
"//grpc/src/compiler:distribution",
30+
"//reflection:distribution",
31+
"//src:distribution",
32+
- "//ts:distribution",
33+
] + glob([
34+
"include/flatbuffers/*.h",
35+
]),

bazel/external/tensorflow_local_changes.patch

Lines changed: 443 additions & 0 deletions
Large diffs are not rendered by default.

bazel/repositories.bzl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def _cc_deps():
121121
# Dependencies with native bazel build files.
122122

123123
_bazel_repo("upb")
124+
_bazel_repo("com_google_flatbuffers", patches = ["//bazel/external:flatbuffers_local_changes.patch"], patch_args = ["-p1"])
124125
_bazel_repo("com_google_protobuf", patches = ["//bazel/external:protobuf_text_format_v31_part1.patch", "//bazel/external:protobuf_text_format_v31_part2.patch", "//bazel/external:protobuf_py_proto_library.patch"], patch_args = ["-p1"])
125126
# _bazel_repo("com_google_protobuf", patches = ["//bazel/external:protobuf_text_format.patch", "//bazel/external:protobuf_warning.patch"], patch_args = ["-p1"])
126127
_bazel_repo("com_github_grpc_grpc", patches = ["//bazel/external:grpc_go_toolchain.patch"], patch_args = ["-p1"])
@@ -133,17 +134,11 @@ def _cc_deps():
133134
_bazel_repo("com_github_google_glog")
134135
_bazel_repo("com_google_absl")
135136
_bazel_repo("abseil-cpp") # Alias for gRPC/Protobuf
136-
_bazel_repo("com_google_flatbuffers")
137137
_bazel_repo("cpuinfo", patches = ["//bazel/external:cpuinfo.patch"], patch_args = ["-p1"])
138138
# _bazel_repo("org_tensorflow", patches = ["//bazel/external:tensorflow_disable_llvm.patch", "//bazel/external:tensorflow_disable_mirrors.patch", "//bazel/external:tensorflow_disable_py.patch"], patch_args = ["-p1"])
139-
_bazel_repo("org_tensorflow", patches = ["//bazel/external:tensorflow_disable_py_v2.20.patch"], patch_args = ["-p1"])
139+
# _bazel_repo("org_tensorflow", patches = ["//bazel/external:tensorflow_disable_py_v2.20.patch"], patch_args = ["-p1"])
140+
_bazel_repo("org_tensorflow", patches = ["//bazel/external:tensorflow_local_changes.patch"], patch_args = ["-p1"])
140141

141-
# Stub CUDA repository for TensorFlow (we don't need CUDA support)
142-
native.new_local_repository(
143-
name = "local_config_cuda",
144-
path = "bazel/external/local_config_cuda",
145-
build_file_content = "",
146-
)
147142
_bazel_repo("com_github_neargye_magic_enum")
148143
_bazel_repo("com_github_thoughtspot_threadstacks")
149144
_bazel_repo("com_googlesource_code_re2", patches = ["//bazel/external:re2_warning.patch"], patch_args = ["-p1"])

bazel/repository_locations.bzl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,9 +252,9 @@ REPOSITORY_LOCATIONS = dict(
252252
urls = ["https://github.com/google/farmhash/archive/2f0e005b81e296fa6963e395626137cf729b710c.tar.gz"],
253253
),
254254
com_google_flatbuffers = dict(
255-
sha256 = "e2dc24985a85b278dd06313481a9ca051d048f9474e0f199e372fea3ea4248c9",
256-
strip_prefix = "flatbuffers-2.0.6",
257-
urls = ["https://github.com/google/flatbuffers/archive/refs/tags/v2.0.6.tar.gz"],
255+
sha256 = "4157c5cacdb59737c5d627e47ac26b140e9ee28b1102f812b36068aab728c1ed",
256+
strip_prefix = "flatbuffers-24.3.25",
257+
urls = ["https://github.com/google/flatbuffers/archive/refs/tags/v24.3.25.tar.gz"],
258258
),
259259
com_google_googletest = dict(
260260
sha256 = "65fab701d9829d38cb77c14acdc431d2108bfdbf8979e40eb8ae567edf10b27c",

bazel/toolchain_transitions.bzl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ java_graal_binary, _java_graal_binary_internal = with_cfg(native.java_binary).se
2121
"java_runtime_version", "remotejdk_openjdk_graal_17").build()
2222

2323
cc_clang_binary, _cc_clang_binary_internal = with_cfg(native.cc_binary).set(
24-
Label("@//bazel/cc_toolchains:compiler"), "clang").build()
24+
Label("@//bazel/cc_toolchains:compiler"), "clang").set(
25+
Label("@//bazel/cc_toolchains:libc_version"), "glibc2_36").build()
2526

2627
qemu_interactive_runner, _qemu_interactive_runner_internal = with_cfg(qemu_with_kernel_interactive_runner).set(
2728
Label("@//bazel/cc_toolchains:libc_version"), "glibc2_36").build()

src/carnot/exec/ml/BUILD.bazel

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ pl_cc_library(
2626
"*.h",
2727
],
2828
exclude = [
29-
"transformer_executor.*",
3029
"**/*_test.cc",
3130
"**/*_benchmark.cc",
3231
"**/*_mock.h",
@@ -39,8 +38,8 @@ pl_cc_library(
3938
"//src/shared/types:cc_library",
4039
"@com_github_google_sentencepiece//:libsentencepiece",
4140
"@com_github_tencent_rapidjson//:rapidjson",
42-
# "@org_tensorflow//tensorflow/lite:framework",
43-
# "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
41+
"@org_tensorflow//tensorflow/lite:framework",
42+
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
4443
"@eigen_archive//:eigen3",
4544
],
4645
)

src/carnot/funcs/builtins/BUILD.bazel

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -103,42 +103,42 @@ pl_cc_test(
103103
],
104104
)
105105

106-
# pl_cc_test(
107-
# name = "ml_ops_test",
108-
# srcs = ["ml_ops_test.cc"],
109-
# args = [
110-
# "--sentencepiece_dir=$(location //:sentencepiece.proto)",
111-
# "--embedding_dir=$(location //:embedding.proto)",
112-
# ],
113-
# data = [
114-
# "//:embedding.proto",
115-
# "//:sentencepiece.proto",
116-
# ],
117-
# deps = [
118-
# ":cc_library",
119-
# "//src/carnot/exec/ml:eigen_testutils",
120-
# "//src/carnot/udf:udf_testutils",
121-
# ],
122-
# )
123-
124-
# pl_cc_binary(
125-
# name = "ml_ops_benchmark",
126-
# testonly = 1,
127-
# srcs = ["ml_ops_benchmark.cc"],
128-
# args = [
129-
# "--sentencepiece_dir=$(location //:sentencepiece.proto)",
130-
# "--embedding_dir=$(location //:embedding.proto)",
131-
# ],
132-
# data = [
133-
# "//:embedding.proto",
134-
# "//:sentencepiece.proto",
135-
# ],
136-
# deps = [
137-
# ":cc_library",
138-
# "//src/common/benchmark:cc_library",
139-
# # "@org_tensorflow//tensorflow/lite:tflite_with_xnnpack",
140-
# ],
141-
# )
106+
pl_cc_test(
107+
name = "ml_ops_test",
108+
srcs = ["ml_ops_test.cc"],
109+
args = [
110+
"--sentencepiece_dir=$(location //:sentencepiece.proto)",
111+
"--embedding_dir=$(location //:embedding.proto)",
112+
],
113+
data = [
114+
"//:embedding.proto",
115+
"//:sentencepiece.proto",
116+
],
117+
deps = [
118+
":cc_library",
119+
"//src/carnot/exec/ml:eigen_testutils",
120+
"//src/carnot/udf:udf_testutils",
121+
],
122+
)
123+
124+
pl_cc_binary(
125+
name = "ml_ops_benchmark",
126+
testonly = 1,
127+
srcs = ["ml_ops_benchmark.cc"],
128+
args = [
129+
"--sentencepiece_dir=$(location //:sentencepiece.proto)",
130+
"--embedding_dir=$(location //:embedding.proto)",
131+
],
132+
data = [
133+
"//:embedding.proto",
134+
"//:sentencepiece.proto",
135+
],
136+
deps = [
137+
":cc_library",
138+
"//src/common/benchmark:cc_library",
139+
# "@org_tensorflow//tensorflow/lite:tflite_with_xnnpack",
140+
],
141+
)
142142

143143
pl_cc_test(
144144
name = "request_path_ops_test",

src/carnot/funcs/builtins/ml_ops.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
#include "src/carnot/exec/ml/coreset.h"
3232
#include "src/carnot/exec/ml/kmeans.h"
3333
#include "src/carnot/exec/ml/sampling.h"
34-
/* #include "src/carnot/exec/ml/transformer_executor.h" */
34+
#include "src/carnot/exec/ml/transformer_executor.h"
3535
#include "src/carnot/udf/model_executor.h"
3636
#include "src/carnot/udf/registry.h"
3737
#include "src/common/base/utils.h"
@@ -53,11 +53,11 @@ class TransformerUDF : public udf::ScalarUDF {
5353
public:
5454
TransformerUDF() : TransformerUDF("/embedding.proto") {}
5555
explicit TransformerUDF(std::string model_proto_path) : model_proto_path_(model_proto_path) {}
56-
StringValue Exec(FunctionContext* /*ctx*/, StringValue /*doc*/) {
57-
/* auto executor = */
58-
/* ctx->model_pool()->GetModelExecutor<exec::ml::TransformerExecutor>(model_proto_path_); */
59-
std::string output = "";
60-
/* executor->Execute(doc, &output); */
56+
StringValue Exec(FunctionContext* ctx, StringValue doc) {
57+
auto executor =
58+
ctx->model_pool()->GetModelExecutor<exec::ml::TransformerExecutor>(model_proto_path_);
59+
std::string output;
60+
executor->Execute(doc, &output);
6161
return output;
6262
}
6363

0 commit comments

Comments
 (0)