@@ -314,34 +314,14 @@ def _diagonal_indices(H, W, k):
314314
315315def diag (x , k = 0 ):
316316 x = convert_to_tensor (x )
317-
318- if len (x .shape ) == 2 :
319- return x [_diagonal_indices (* x .shape , k )]
320-
321- elif len (x .shape ) == 1 :
322- N = x .shape [0 ] + abs (k )
323- zeros = mx .zeros ((N , N ))
324- zeros [_diagonal_indices (N , N , k )] = x
325- return zeros
326-
327- else :
328- raise ValueError ("Input must be 1d or 2d" )
317+ if x .dtype in [mx .int64 , mx .uint64 ]:
318+ return mx .diag (x , k = k , stream = mx .Device (type = mx .DeviceType .cpu ))
319+ return mx .diag (x , k = k )
329320
330321
331322def diagonal (x , offset = 0 , axis1 = 0 , axis2 = 1 ):
332323 x = convert_to_tensor (x )
333-
334- ndim = x .ndim
335- axis1 = (ndim + axis1 ) % ndim
336- axis2 = (ndim + axis2 ) % ndim
337-
338- max_axis = builtins .max (axis1 , axis2 )
339- indices = [slice (None ) for _ in range (max_axis + 1 )]
340- indices [axis1 ], indices [axis2 ] = _diagonal_indices (
341- x .shape [axis1 ], x .shape [axis2 ], offset
342- )
343-
344- return x [indices ]
324+ return mx .diagonal (x , offset = offset , axis1 = axis1 , axis2 = axis2 )
345325
346326
347327def diff (x , n = 1 , axis = - 1 ):
0 commit comments