Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 3f12173

Browse files
0x0539copybara-github
authored andcommitted
A more generic packing op for seqio
PiperOrigin-RevId: 373172087
1 parent 51256a2 commit 3f12173

File tree

1 file changed

+24
-17
lines changed

1 file changed

+24
-17
lines changed

tensor2tensor/data_generators/ops/pack_sequences_ops.cc

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -187,27 +187,34 @@ REGISTER_OP("PackSequencesK")
187187
return InvalidArgument(
188188
"`inputs` and `max_lengths` had different numbers of elements");
189189
}
190-
auto inputs = ctx->input_handle_shapes_and_types(0);
191-
std::vector<ShapeHandle> output_dims(inputs->size());
192-
std::vector<ShapeHandle> segmentation_dims(inputs->size());
193-
std::vector<ShapeHandle> position_dims(inputs->size());
194-
for (int i = 0; i < inputs->size(); i++) {
195-
auto input = inputs->at(i);
196-
int rank = ctx->Rank(input.shape);
197-
std::vector<DimensionHandle> dims(rank);
198-
for (int r = 0; r < rank; r++) {
199-
dims.push_back(ctx->UnknownDim());
200-
}
201-
output_dims.push_back(ctx->MakeShape(dims));
202-
segmentation_dims.push_back(
190+
std::vector<ShapeHandle> input_shapes;
191+
TF_RETURN_IF_ERROR(ctx->input("inputs", &input_shapes));
192+
std::vector<ShapeHandle> output_shapes;
193+
std::vector<ShapeHandle> segmentation_shapes;
194+
std::vector<ShapeHandle> position_shapes;
195+
for (int i = 0; i < input_shapes.size(); i++) {
196+
const auto& input_shape = input_shapes.at(i);
197+
int rank = ctx->Rank(input_shape);
198+
segmentation_shapes.push_back(
203199
ctx->Matrix(ctx->UnknownDim(), ctx->UnknownDim()));
204-
position_dims.push_back(
200+
position_shapes.push_back(
205201
ctx->Matrix(ctx->UnknownDim(), ctx->UnknownDim()));
202+
if (rank == 2) {
203+
output_shapes.push_back(
204+
ctx->MakeShape({ctx->UnknownDim(), ctx->UnknownDim()}));
205+
} else if (rank == 3) {
206+
output_shapes.push_back(
207+
ctx->MakeShape({ctx->UnknownDim(), ctx->UnknownDim(),
208+
ctx->Value(ctx->Dim(input_shape, 2))}));
209+
} else {
210+
return InvalidArgument(
211+
"Only rank 2 and rank 3 inputs are supported");
212+
}
206213
}
207-
TF_RETURN_IF_ERROR(ctx->set_output("outputs_packed", output_dims));
214+
TF_RETURN_IF_ERROR(ctx->set_output("outputs_packed", output_shapes));
208215
TF_RETURN_IF_ERROR(
209-
ctx->set_output("outputs_segmentation", segmentation_dims));
210-
TF_RETURN_IF_ERROR(ctx->set_output("outputs_position", position_dims));
216+
ctx->set_output("outputs_segmentation", segmentation_shapes));
217+
TF_RETURN_IF_ERROR(ctx->set_output("outputs_position", position_shapes));
211218
return Status::OK();
212219
});
213220

0 commit comments

Comments
 (0)