@@ -187,27 +187,34 @@ REGISTER_OP("PackSequencesK")
187187 return InvalidArgument (
188188 " `inputs` and `max_lengths` had different numbers of elements" );
189189 }
190- auto inputs = ctx->input_handle_shapes_and_types (0 );
191- std::vector<ShapeHandle> output_dims (inputs->size ());
192- std::vector<ShapeHandle> segmentation_dims (inputs->size ());
193- std::vector<ShapeHandle> position_dims (inputs->size ());
194- for (int i = 0 ; i < inputs->size (); i++) {
195- auto input = inputs->at (i);
196- int rank = ctx->Rank (input.shape );
197- std::vector<DimensionHandle> dims (rank);
198- for (int r = 0 ; r < rank; r++) {
199- dims.push_back (ctx->UnknownDim ());
200- }
201- output_dims.push_back (ctx->MakeShape (dims));
202- segmentation_dims.push_back (
190+ std::vector<ShapeHandle> input_shapes;
191+ TF_RETURN_IF_ERROR (ctx->input (" inputs" , &input_shapes));
192+ std::vector<ShapeHandle> output_shapes;
193+ std::vector<ShapeHandle> segmentation_shapes;
194+ std::vector<ShapeHandle> position_shapes;
195+ for (int i = 0 ; i < input_shapes.size (); i++) {
196+ const auto & input_shape = input_shapes.at (i);
197+ int rank = ctx->Rank (input_shape);
198+ segmentation_shapes.push_back (
203199 ctx->Matrix (ctx->UnknownDim (), ctx->UnknownDim ()));
204- position_dims .push_back (
200+ position_shapes .push_back (
205201 ctx->Matrix (ctx->UnknownDim (), ctx->UnknownDim ()));
202+ if (rank == 2 ) {
203+ output_shapes.push_back (
204+ ctx->MakeShape ({ctx->UnknownDim (), ctx->UnknownDim ()}));
205+ } else if (rank == 3 ) {
206+ output_shapes.push_back (
207+ ctx->MakeShape ({ctx->UnknownDim (), ctx->UnknownDim (),
208+ ctx->Value (ctx->Dim (input_shape, 2 ))}));
209+ } else {
210+ return InvalidArgument (
211+ " Only rank 2 and rank 3 inputs are supported" );
212+ }
206213 }
207- TF_RETURN_IF_ERROR (ctx->set_output (" outputs_packed" , output_dims ));
214+ TF_RETURN_IF_ERROR (ctx->set_output (" outputs_packed" , output_shapes ));
208215 TF_RETURN_IF_ERROR (
209- ctx->set_output (" outputs_segmentation" , segmentation_dims ));
210- TF_RETURN_IF_ERROR (ctx->set_output (" outputs_position" , position_dims ));
216+ ctx->set_output (" outputs_segmentation" , segmentation_shapes ));
217+ TF_RETURN_IF_ERROR (ctx->set_output (" outputs_position" , position_shapes ));
211218 return Status::OK ();
212219 });
213220
0 commit comments