Skip to content

Commit 01bb1fb

Browse files
Neural-Link Teamtensorflow-copybara
authored andcommitted
Add an increment_last_dim() util function for handling the logit layer with dot product.
PiperOrigin-RevId: 368308163
1 parent 428ee71 commit 01bb1fb

File tree

4 files changed

+157
-0
lines changed

4 files changed

+157
-0
lines changed

research/carls/util/BUILD

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Description:
16+
# Build rules for the utitility component of CARLS.
17+
18+
# Placeholder for internal Python strict & test compatibility macro.
19+
20+
package(
21+
default_visibility = ["//research/carls:internal"],
22+
licenses = ["notice"], # Apache 2.0
23+
)
24+
25+
py_library(
26+
name = "array_ops",
27+
srcs = ["array_ops.py"],
28+
srcs_version = "PY3",
29+
deps = [
30+
# package tensorflow
31+
],
32+
)
33+
34+
py_test(
35+
name = "array_ops_test",
36+
srcs = ["array_ops_test.py"],
37+
python_version = "PY3",
38+
srcs_version = "PY3",
39+
deps = [
40+
":array_ops",
41+
# package tensorflow
42+
],
43+
)

research/carls/util/BUILD.oss

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Description:
16+
# Build rules for CARLS APIs in Neural Structured Learning.
17+
18+
package(
19+
default_visibility = ["//research/carls:internal"],
20+
licenses = ["notice"], # Apache 2.0
21+
)
22+
23+
py_library(
24+
name = "array_ops",
25+
srcs = ["array_ops.py"],
26+
srcs_version = "PY3",
27+
deps = [
28+
# package tensorflow
29+
],
30+
)
31+
32+
py_test(
33+
name = "array_ops_test",
34+
size = "small",
35+
srcs = ["array_ops_test.py"],
36+
python_version = "PY3",
37+
srcs_version = "PY3",
38+
deps = [
39+
# package tensorflow
40+
],
41+
)

research/carls/util/array_ops.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Array related ops."""
15+
16+
import tensorflow as tf
17+
18+
19+
def increment_last_dim(input_tensor: tf.Tensor,
20+
default_value: float) -> tf.Tensor:
21+
"""Grows the size of last dimension of given `input_tensor` by one.
22+
23+
Examples:
24+
- [[1, 2], [3, 4]] -> [[1, 2, 1], [3, 4, 1]] (default_value = 1).
25+
- [1, 2, 3] -> [1, 2, 3, 4] (default_value = 4).
26+
27+
Args:
28+
input_tensor: a float tf.Tensor whose last dimension is to be incremented.
29+
default_value: a float value denoting the default value for the increased
30+
part.
31+
32+
Returns:
33+
A new `tf.Tensor` with increased last dimension size.
34+
"""
35+
input_tensor = tf.dtypes.cast(input_tensor, tf.float32)
36+
inc_tensor = tf.ones(tf.shape(input_tensor)[:-1])
37+
inc_tensor = tf.expand_dims(inc_tensor, -1) * default_value
38+
return tf.concat([input_tensor, inc_tensor], axis=-1)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for neural_structured_learning.research.carls.util.array_ops."""
15+
16+
from research.carls.util import array_ops
17+
import tensorflow as tf
18+
19+
20+
class ArrayOpsTest(tf.test.TestCase):
21+
22+
def test_increment_last_dim(self):
23+
# 1D case
24+
input_tensor = tf.constant([2])
25+
new_tensor = array_ops.increment_last_dim(input_tensor, 1)
26+
self.assertAllClose([2, 1], new_tensor.numpy())
27+
28+
# 2D case
29+
input_tensor = tf.constant([[1, 2], [3, 4]])
30+
new_tensor = array_ops.increment_last_dim(input_tensor, 10)
31+
self.assertAllClose([[1, 2, 10], [3, 4, 10]], new_tensor.numpy())
32+
33+
34+
if __name__ == '__main__':
35+
tf.test.main()

0 commit comments

Comments
 (0)