@@ -31,7 +31,6 @@ limitations under the License.
3131#include " tensorflow/core/platform/logging.h"
3232#include " tensorflow/core/platform/macros.h"
3333#include " tensorflow/core/platform/types.h"
34-
3534#include " tensorflow_compression/cc/kernels/range_coder.h"
3635#include " tensorflow_compression/cc/kernels/range_coding_kernels_util.h"
3736
@@ -40,21 +39,21 @@ namespace {
4039namespace errors = tensorflow::errors;
4140namespace gtl = tensorflow::gtl;
4241using tensorflow::DEVICE_CPU;
42+ using tensorflow::int16;
43+ using tensorflow::int32;
44+ using tensorflow::int64;
4345using tensorflow::OpKernel;
4446using tensorflow::OpKernelConstruction;
4547using tensorflow::OpKernelContext;
4648using tensorflow::Status;
49+ using tensorflow::string;
4750using tensorflow::Tensor;
4851using tensorflow::TensorShape;
4952using tensorflow::TensorShapeUtils;
5053using tensorflow::TTypes;
51- using tensorflow::int16;
52- using tensorflow::int32;
53- using tensorflow::int64;
54- using tensorflow::string;
55- using tensorflow::uint8;
5654using tensorflow::uint32;
5755using tensorflow::uint64;
56+ using tensorflow::uint8;
5857
5958// A helper class to iterate over data and cdf simultaneously, while cdf is
6059// broadcasted to data.
@@ -151,14 +150,42 @@ Status CheckCdfShape(const TensorShape& data_shape,
151150 return Status::OK ();
152151}
153152
154- // Non-incremental encoder op -------------------------------------------------
153+ tensorflow::Status CheckCdfValues (int precision,
154+ const tensorflow::Tensor& cdf_tensor) {
155+ const auto cdf = cdf_tensor.flat_inner_dims <int32, 2 >();
156+ const auto size = cdf.dimension (1 );
157+ if (size <= 2 ) {
158+ return errors::InvalidArgument (" CDF size should be > 2: " , size);
159+ }
160+
161+ const int32 upper_bound = 1 << precision;
162+ for (int64 i = 0 ; i < cdf.dimension (0 ); ++i) {
163+ auto slice = tensorflow::gtl::ArraySlice<int32>(&cdf (i, 0 ), size);
164+ if (slice[0 ] != 0 || slice[size - 1 ] != upper_bound) {
165+ return errors::InvalidArgument (" CDF should start from 0 and end at " ,
166+ upper_bound, " : cdf[0]=" , slice[0 ],
167+ " , cdf[^1]=" , slice[size - 1 ]);
168+ }
169+ for (int64 j = 0 ; j + 1 < size; ++j) {
170+ if (slice[j + 1 ] <= slice[j]) {
171+ return errors::InvalidArgument (" CDF is not monotonic" );
172+ }
173+ }
174+ }
175+ return tensorflow::Status::OK ();
176+ }
177+
155178class RangeEncodeOp : public OpKernel {
156179 public:
157180 explicit RangeEncodeOp (OpKernelConstruction* context) : OpKernel(context) {
158181 OP_REQUIRES_OK (context, context->GetAttr (" precision" , &precision_));
159182 OP_REQUIRES (context, 0 < precision_ && precision_ <= 16 ,
160183 errors::InvalidArgument (" `precision` must be in [1, 16]: " ,
161184 precision_));
185+ OP_REQUIRES_OK (context, context->GetAttr (" debug_level" , &debug_level_));
186+ OP_REQUIRES (context, debug_level_ == 0 || debug_level_ == 1 ,
187+ errors::InvalidArgument (" `debug_level` must be 0 or 1: " ,
188+ debug_level_));
162189 }
163190
164191 void Compute (OpKernelContext* context) override {
@@ -167,6 +194,10 @@ class RangeEncodeOp : public OpKernel {
167194
168195 OP_REQUIRES_OK (context, CheckCdfShape (data.shape (), cdf.shape ()));
169196
197+ if (debug_level_ > 0 ) {
198+ OP_REQUIRES_OK (context, CheckCdfValues (precision_, cdf));
199+ }
200+
170201 std::vector<int64> data_shape, cdf_shape;
171202 OP_REQUIRES_OK (
172203 context, MergeAxes (data.shape (), cdf.shape (), &data_shape, &cdf_shape));
@@ -177,10 +208,12 @@ class RangeEncodeOp : public OpKernel {
177208 string* output = &output_tensor->scalar <string>()();
178209
179210 switch (data_shape.size ()) {
180- #define RANGE_ENCODE_CASE (dims ) \
181- case dims: { \
182- RangeEncodeImpl<dims>(data.flat <int16>(), data_shape, \
183- cdf.flat_inner_dims <int32, 2 >(), cdf_shape, output); \
211+ #define RANGE_ENCODE_CASE (dims ) \
212+ case dims: { \
213+ OP_REQUIRES_OK (context, \
214+ RangeEncodeImpl<dims>(data.flat <int16>(), data_shape, \
215+ cdf.flat_inner_dims <int32, 2 >(), \
216+ cdf_shape, output)); \
184217 } break
185218 RANGE_ENCODE_CASE (1 );
186219 RANGE_ENCODE_CASE (2 );
@@ -199,10 +232,11 @@ class RangeEncodeOp : public OpKernel {
199232
200233 private:
201234 template <int N>
202- void RangeEncodeImpl (TTypes<int16>::ConstFlat data,
203- gtl::ArraySlice<int64> data_shape,
204- TTypes<int32>::ConstMatrix cdf,
205- gtl::ArraySlice<int64> cdf_shape, string* output) const {
235+ tensorflow::Status RangeEncodeImpl (TTypes<int16>::ConstFlat data,
236+ gtl::ArraySlice<int64> data_shape,
237+ TTypes<int32>::ConstMatrix cdf,
238+ gtl::ArraySlice<int64> cdf_shape,
239+ string* output) const {
206240 const int64 data_size = data.size ();
207241 const int64 cdf_size = cdf.size ();
208242 const int64 chip_size = cdf.dimension (1 );
@@ -214,8 +248,15 @@ class RangeEncodeOp : public OpKernel {
214248 const auto pair = view.Next ();
215249
216250 const int64 index = *pair.first ;
217- DCHECK_GE (index, 0 );
218- DCHECK_LT (index + 1 , chip_size);
251+ if (debug_level_ > 0 ) {
252+ if (index < 0 || chip_size <= index + 1 ) {
253+ return errors::InvalidArgument (" 'data' value not in [0, " ,
254+ chip_size - 1 , " ): value=" , index);
255+ }
256+ } else {
257+ DCHECK_GE (index, 0 );
258+ DCHECK_LT (index + 1 , chip_size);
259+ }
219260
220261 const int32* cdf_slice = pair.second ;
221262 DCHECK_LE (cdf_slice + chip_size, cdf.data () + cdf_size);
@@ -226,21 +267,26 @@ class RangeEncodeOp : public OpKernel {
226267 }
227268
228269 encoder.Finalize (output);
270+ return tensorflow::Status::OK ();
229271 }
230272
231273 int precision_;
274+ int debug_level_;
232275};
233276
234277REGISTER_KERNEL_BUILDER (Name(" RangeEncode" ).Device(DEVICE_CPU), RangeEncodeOp);
235278
236- // Non-incremental decoder op -------------------------------------------------
237279class RangeDecodeOp : public OpKernel {
238280 public:
239281 explicit RangeDecodeOp (OpKernelConstruction* context) : OpKernel(context) {
240282 OP_REQUIRES_OK (context, context->GetAttr (" precision" , &precision_));
241283 OP_REQUIRES (context, 0 < precision_ && precision_ <= 16 ,
242284 errors::InvalidArgument (" `precision` must be in [1, 16]: " ,
243285 precision_));
286+ OP_REQUIRES_OK (context, context->GetAttr (" debug_level" , &debug_level_));
287+ OP_REQUIRES (context, debug_level_ == 0 || debug_level_ == 1 ,
288+ errors::InvalidArgument (" `debug_level` must be 0 or 1: " ,
289+ debug_level_));
244290 }
245291
246292 void Compute (OpKernelContext* context) override {
@@ -254,11 +300,16 @@ class RangeDecodeOp : public OpKernel {
254300 OP_REQUIRES (context, TensorShapeUtils::IsVector (shape.shape ()),
255301 errors::InvalidArgument (" Invalid `shape` shape: " ,
256302 shape.shape ().DebugString ()));
303+
257304 TensorShape output_shape;
258305 OP_REQUIRES_OK (context, TensorShapeUtils::MakeShape (shape.vec <int32>(),
259306 &output_shape));
260307 OP_REQUIRES_OK (context, CheckCdfShape (output_shape, cdf.shape ()));
261308
309+ if (debug_level_ > 0 ) {
310+ OP_REQUIRES_OK (context, CheckCdfValues (precision_, cdf));
311+ }
312+
262313 std::vector<int64> data_shape, cdf_shape;
263314 OP_REQUIRES_OK (
264315 context, MergeAxes (output_shape, cdf.shape (), &data_shape, &cdf_shape));
@@ -269,10 +320,12 @@ class RangeDecodeOp : public OpKernel {
269320 OP_REQUIRES_OK (context, context->allocate_output (0 , output_shape, &output));
270321
271322 switch (data_shape.size ()) {
272- #define RANGE_DECODE_CASE (dim ) \
273- case dim: { \
274- RangeDecodeImpl<dim>(output->flat <int16>(), data_shape, \
275- cdf.flat_inner_dims <int32>(), cdf_shape, encoded); \
323+ #define RANGE_DECODE_CASE (dim ) \
324+ case dim: { \
325+ OP_REQUIRES_OK ( \
326+ context, RangeDecodeImpl<dim>(output->flat <int16>(), data_shape, \
327+ cdf.flat_inner_dims <int32>(), cdf_shape, \
328+ encoded)); \
276329 } break
277330 RANGE_DECODE_CASE (1 );
278331 RANGE_DECODE_CASE (2 );
@@ -291,11 +344,11 @@ class RangeDecodeOp : public OpKernel {
291344
292345 private:
293346 template <int N>
294- void RangeDecodeImpl (TTypes<int16>::Flat output,
295- gtl::ArraySlice<int64> output_shape,
296- TTypes<int32>::ConstMatrix cdf,
297- gtl::ArraySlice<int64> cdf_shape,
298- const string& encoded) const {
347+ tensorflow::Status RangeDecodeImpl (TTypes<int16>::Flat output,
348+ gtl::ArraySlice<int64> output_shape,
349+ TTypes<int32>::ConstMatrix cdf,
350+ gtl::ArraySlice<int64> cdf_shape,
351+ const string& encoded) const {
299352 BroadcastRange<int16, int32, N> view{output.data (), output_shape,
300353 cdf.data (), cdf_shape};
301354
@@ -315,11 +368,13 @@ class RangeDecodeOp : public OpKernel {
315368 const int32* cdf_slice = pair.second ;
316369 DCHECK_LE (cdf_slice + chip_size, cdf.data () + cdf_size);
317370
318- *data = decoder.Decode (gtl::ArraySlice<int32> {cdf_slice, chip_size});
371+ *data = decoder.Decode ({cdf_slice, chip_size});
319372 }
373+ return tensorflow::Status::OK ();
320374 }
321375
322376 int precision_;
377+ int debug_level_;
323378};
324379
325380REGISTER_KERNEL_BUILDER (Name(" RangeDecode" ).Device(DEVICE_CPU), RangeDecodeOp);
0 commit comments