@@ -300,20 +300,28 @@ def resize(
300300 size ,
301301 interpolation = "bilinear" ,
302302 antialias = False ,
303+ crop_to_aspect_ratio = False ,
304+ pad_to_aspect_ratio = False ,
305+ fill_mode = "constant" ,
306+ fill_value = 0.0 ,
303307 data_format = "channels_last" ,
304308):
305309 if antialias :
306310 raise NotImplementedError (
307311 "Antialiasing not implemented for the MLX backend"
308312 )
309-
313+ if pad_to_aspect_ratio and crop_to_aspect_ratio :
314+ raise ValueError (
315+ "Only one of `pad_to_aspect_ratio` & `crop_to_aspect_ratio` "
316+ "can be `True`."
317+ )
310318 if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS .keys ():
311319 raise ValueError (
312320 "Invalid value for argument `interpolation`. Expected of one "
313321 f"{ set (AFFINE_TRANSFORM_INTERPOLATIONS .keys ())} . Received: "
314322 f"interpolation={ interpolation } "
315323 )
316-
324+ target_height , target_width = size
317325 size = tuple (size )
318326 image = convert_to_tensor (image )
319327
@@ -324,6 +332,127 @@ def resize(
324332 f"image.shape={ image .shape } "
325333 )
326334
335+ if crop_to_aspect_ratio :
336+ shape = image .shape
337+ if data_format == "channels_last" :
338+ height , width = shape [- 3 ], shape [- 2 ]
339+ else :
340+ height , width = shape [- 2 ], shape [- 1 ]
341+ crop_height = int (float (width * target_height ) / target_width )
342+ crop_height = min (height , crop_height )
343+ crop_width = int (float (height * target_width ) / target_height )
344+ crop_width = min (width , crop_width )
345+ crop_box_hstart = int (float (height - crop_height ) / 2 )
346+ crop_box_wstart = int (float (width - crop_width ) / 2 )
347+ if data_format == "channels_last" :
348+ if len (image .shape ) == 4 :
349+ image = image [
350+ :,
351+ crop_box_hstart : crop_box_hstart + crop_height ,
352+ crop_box_wstart : crop_box_wstart + crop_width ,
353+ :,
354+ ]
355+ else :
356+ image = image [
357+ crop_box_hstart : crop_box_hstart + crop_height ,
358+ crop_box_wstart : crop_box_wstart + crop_width ,
359+ :,
360+ ]
361+ else :
362+ if len (image .shape ) == 4 :
363+ image = image [
364+ :,
365+ :,
366+ crop_box_hstart : crop_box_hstart + crop_height ,
367+ crop_box_wstart : crop_box_wstart + crop_width ,
368+ ]
369+ else :
370+ image = image [
371+ :,
372+ crop_box_hstart : crop_box_hstart + crop_height ,
373+ crop_box_wstart : crop_box_wstart + crop_width ,
374+ ]
375+ elif pad_to_aspect_ratio :
376+ shape = image .shape
377+ batch_size = image .shape [0 ]
378+ if data_format == "channels_last" :
379+ height , width , channels = shape [- 3 ], shape [- 2 ], shape [- 1 ]
380+ else :
381+ channels , height , width = shape [- 3 ], shape [- 2 ], shape [- 1 ]
382+ pad_height = int (float (width * target_height ) / target_width )
383+ pad_height = max (height , pad_height )
384+ pad_width = int (float (height * target_width ) / target_height )
385+ pad_width = max (width , pad_width )
386+ img_box_hstart = int (float (pad_height - height ) / 2 )
387+ img_box_wstart = int (float (pad_width - width ) / 2 )
388+ if data_format == "channels_last" :
389+ if len (image .shape ) == 4 :
390+ padded_img = (
391+ mx .ones (
392+ (
393+ batch_size ,
394+ pad_height + height ,
395+ pad_width + width ,
396+ channels ,
397+ ),
398+ dtype = image .dtype ,
399+ )
400+ * fill_value
401+ )
402+ padded_img [
403+ :,
404+ img_box_hstart : img_box_hstart + height ,
405+ img_box_wstart : img_box_wstart + width ,
406+ :,
407+ ] = image
408+ else :
409+ padded_img = (
410+ mx .ones (
411+ (pad_height + height , pad_width + width , channels ),
412+ dtype = image .dtype ,
413+ )
414+ * fill_value
415+ )
416+ padded_img [
417+ img_box_hstart : img_box_hstart + height ,
418+ img_box_wstart : img_box_wstart + width ,
419+ :,
420+ ] = image
421+ else :
422+ if len (image .shape ) == 4 :
423+ padded_img = (
424+ mx .ones (
425+ (
426+ batch_size ,
427+ channels ,
428+ pad_height + height ,
429+ pad_width + width ,
430+ ),
431+ dtype = image .dtype ,
432+ )
433+ * fill_value
434+ )
435+ padded_img [
436+ :,
437+ :,
438+ img_box_hstart : img_box_hstart + height ,
439+ img_box_wstart : img_box_wstart + width ,
440+ ] = image
441+ else :
442+ padded_img = (
443+ mx .ones (
444+ (channels , pad_height + height , pad_width + width ),
445+ dtype = image .dtype ,
446+ )
447+ * fill_value
448+ )
449+ padded_img [
450+ :,
451+ img_box_hstart : img_box_hstart + height ,
452+ img_box_wstart : img_box_wstart + width ,
453+ ] = image
454+ image = padded_img
455+
327456 # Change to channels_last
328457 if data_format == "channels_first" :
329458 image = (
0 commit comments