@@ -16,6 +16,7 @@ limitations under the License.
1616#include " tensorflow_compression/cc/kernels/range_coder_kernels.h"
1717
1818#include < cstdint>
19+ #include < memory>
1920#include < string>
2021#include < utility>
2122
@@ -114,24 +115,9 @@ Status IndexCDFMatrix(const TTypes<int32_t>::ConstMatrix& table,
114115
115116class RangeEncoderInterface final : public EntropyEncoderInterface {
116117 public:
117- static Status MakeShared (const Tensor lookup,
118- std::shared_ptr<EntropyEncoderInterface>* ptr) {
119- Status status;
120- RangeEncoderInterface* re = new RangeEncoderInterface (lookup);
121- if (lookup.dims () == 1 ) {
122- status = IndexCDFVector (lookup.flat <int32_t >(), &re->lookup_ );
123- } else if (lookup.dims () == 2 ) {
124- status = IndexCDFMatrix (lookup.matrix <int32_t >(), &re->lookup_ );
125- } else {
126- status = errors::InvalidArgument (" `lookup` must be rank 1 or 2." );
127- }
128- if (status.ok ()) {
129- ptr->reset (re);
130- } else {
131- delete re;
132- }
133- return status;
134- }
118+ RangeEncoderInterface (absl::Span<const absl::Span<const int32_t >> lookup,
119+ Tensor hold)
120+ : lookup_(lookup.begin(), lookup.end()), hold_(std::move(hold)) {}
135121
136122 Status Encode (int32_t index, int32_t value) override {
137123 TF_RETURN_IF_ERROR (CheckInRange (" index" , index, 0 , lookup_.size ()));
@@ -153,8 +139,6 @@ class RangeEncoderInterface final : public EntropyEncoderInterface {
153139 }
154140
155141 private:
156- explicit RangeEncoderInterface (Tensor lookup) : hold_(std::move(lookup)) {}
157-
158142 void OverflowEncode (const absl::Span<const int32_t > row, int32_t value) {
159143 const int32_t max_value = row.size () - 3 ;
160144 const int32_t sign = value < 0 ;
@@ -193,24 +177,12 @@ class RangeEncoderInterface final : public EntropyEncoderInterface {
193177
194178class RangeDecoderInterface final : public EntropyDecoderInterface {
195179 public:
196- static Status MakeShared (absl::string_view encoded, const Tensor lookup,
197- std::shared_ptr<EntropyDecoderInterface>* ptr) {
198- Status status;
199- RangeDecoderInterface* rd = new RangeDecoderInterface (encoded, lookup);
200- if (lookup.dims () == 1 ) {
201- status = IndexCDFVector (lookup.flat <int32_t >(), &rd->lookup_ );
202- } else if (lookup.dims () == 2 ) {
203- status = IndexCDFMatrix (lookup.matrix <int32_t >(), &rd->lookup_ );
204- } else {
205- status = errors::InvalidArgument (" `lookup` must be rank 1 or 2." );
206- }
207- if (status.ok ()) {
208- ptr->reset (rd);
209- } else {
210- delete rd;
211- }
212- return status;
213- }
180+ RangeDecoderInterface (absl::string_view encoded,
181+ absl::Span<const absl::Span<const int32_t >> lookup,
182+ Tensor hold)
183+ : lookup_(lookup.begin(), lookup.end()),
184+ decoder_ (encoded),
185+ hold_(std::move(hold)) {}
214186
215187 Status Decode (int32_t index, int32_t * output) override {
216188 TF_RETURN_IF_ERROR (CheckInRange (" index" , index, 0 , lookup_.size ()));
@@ -232,9 +204,6 @@ class RangeDecoderInterface final : public EntropyDecoderInterface {
232204 }
233205
234206 private:
235- RangeDecoderInterface (absl::string_view encoded, Tensor lookup)
236- : decoder_(encoded), hold_(std::move(lookup)) {}
237-
238207 int32_t OverflowDecode (const absl::Span<const int32_t > row) {
239208 constexpr int32_t binary_uniform_cdf[] = {0 , 1 , 2 };
240209 const int32_t max_value = row.size () - 3 ;
@@ -313,11 +282,21 @@ class CreateRangeEncoderOp : public tensorflow::OpKernel {
313282 context->allocate_output (0 , handle_shape, &output_tensor));
314283
315284 const Tensor& lookup = context->input (1 );
285+ OP_REQUIRES (context, lookup.dims () == 1 || lookup.dims () == 2 ,
286+ errors::InvalidArgument (" `lookup` must be rank 1 or 2." ));
287+
288+ std::vector<absl::Span<const int32_t >> table;
289+ if (lookup.dims () == 1 ) {
290+ OP_REQUIRES_OK (context, IndexCDFVector (lookup.flat <int32_t >(), &table));
291+ } else {
292+ DCHECK_EQ (lookup.dims (), 2 );
293+ OP_REQUIRES_OK (context, IndexCDFMatrix (lookup.matrix <int32_t >(), &table));
294+ }
295+
316296 auto output = output_tensor->flat <Variant>();
317297 for (int64_t i = 0 ; i < output.size (); ++i) {
318298 EntropyEncoderVariant wrap;
319- OP_REQUIRES_OK (context,
320- RangeEncoderInterface::MakeShared (lookup, &wrap.encoder ));
299+ wrap.encoder = std::make_shared<RangeEncoderInterface>(table, lookup);
321300 output (i) = std::move (wrap);
322301 }
323302 }
@@ -388,10 +367,10 @@ class EntropyEncodeChannelOp : public tensorflow::OpKernel {
388367 context->SetStatus (status); \
389368 return ; \
390369 }
391- #define REQUIRES_OK (status ) \
392- { \
393- auto s = (status); \
394- REQUIRES (s.ok (), s); \
370+ #define REQUIRES_OK (status ) \
371+ { \
372+ auto s = (status); \
373+ REQUIRES (s.ok (), s); \
395374 }
396375
397376 const int64_t num_elements = value.dimension (1 );
@@ -484,10 +463,10 @@ class EntropyEncodeIndexOp : public tensorflow::OpKernel {
484463 context->SetStatus (status); \
485464 return ; \
486465 }
487- #define REQUIRES_OK (status ) \
488- { \
489- auto s = (status); \
490- REQUIRES (s.ok (), s); \
466+ #define REQUIRES_OK (status ) \
467+ { \
468+ auto s = (status); \
469+ REQUIRES (s.ok (), s); \
491470 }
492471
493472 const int64_t num_elements = value.dimension (1 );
@@ -560,11 +539,22 @@ class CreateRangeDecoderOp : public tensorflow::OpKernel {
560539 &output_tensor));
561540
562541 const Tensor& lookup = context->input (1 );
542+ OP_REQUIRES (context, lookup.dims () == 1 || lookup.dims () == 2 ,
543+ errors::InvalidArgument (" `lookup` must be rank 1 or 2." ));
544+
545+ std::vector<absl::Span<const int32_t >> table;
546+ if (lookup.dims () == 1 ) {
547+ OP_REQUIRES_OK (context, IndexCDFVector (lookup.flat <int32_t >(), &table));
548+ } else {
549+ DCHECK_EQ (lookup.dims (), 2 );
550+ OP_REQUIRES_OK (context, IndexCDFMatrix (lookup.matrix <int32_t >(), &table));
551+ }
552+
563553 auto output = output_tensor->flat <Variant>();
564554 for (int64_t i = 0 ; i < output.size (); ++i) {
565555 EntropyDecoderVariant wrap;
566- OP_REQUIRES_OK (context, RangeDecoderInterface::MakeShared (
567- encoded (i), lookup, &wrap. decoder ) );
556+ wrap. decoder =
557+ std::make_shared<RangeDecoderInterface>( encoded (i), table, lookup );
568558 wrap.holder = encoded_tensor;
569559 output (i) = std::move (wrap);
570560 }
@@ -636,10 +626,10 @@ class EntropyDecodeChannelOp : public tensorflow::OpKernel {
636626 context->SetStatus (status); \
637627 return ; \
638628 }
639- #define REQUIRES_OK (status ) \
640- { \
641- auto s = (status); \
642- REQUIRES (s.ok (), s); \
629+ #define REQUIRES_OK (status ) \
630+ { \
631+ auto s = (status); \
632+ REQUIRES (s.ok (), s); \
643633 }
644634
645635 const int64_t num_elements = output.dimension (1 );
@@ -736,10 +726,10 @@ class EntropyDecodeIndexOp : public tensorflow::OpKernel {
736726 context->SetStatus (status); \
737727 return ; \
738728 }
739- #define REQUIRES_OK (status ) \
740- { \
741- auto s = (status); \
742- REQUIRES (s.ok (), s); \
729+ #define REQUIRES_OK (status ) \
730+ { \
731+ auto s = (status); \
732+ REQUIRES (s.ok (), s); \
743733 }
744734
745735 const int64_t num_elements = output.dimension (1 );
0 commit comments