diff --git a/tests/jax/test_permutation.py b/tests/jax/test_permutation.py index 43f2553eed..5bb59c6ed5 100644 --- a/tests/jax/test_permutation.py +++ b/tests/jax/test_permutation.py @@ -19,11 +19,6 @@ from utils import assert_allclose, pytest_parametrize_wrapper -# ============================================================================= -# Test parameter definitions with L0 (fast) and L2 (comprehensive) levels -# ============================================================================= - -# All dispatch/combine test cases ALL_DISPATCH_COMBINE_CASES = [ (128, 5, 128, 3), (1024, 8, 128, 8), @@ -35,7 +30,6 @@ "L2": ALL_DISPATCH_COMBINE_CASES, } -# All sort chunks test cases ALL_SORT_CHUNKS_CASES = [ (8, 4096, 1280), (64, 4096, 4096), @@ -46,7 +40,6 @@ "L2": ALL_SORT_CHUNKS_CASES, } -# All dispatch/combine with padding test cases ALL_DISPATCH_COMBINE_PADDING_CASES = [ (128, 5, 128, 3, 8), (1024, 8, 128, 8, 16), @@ -58,14 +51,12 @@ "L2": ALL_DISPATCH_COMBINE_PADDING_CASES, } -# Dtypes for testing ALL_DTYPES = [jnp.float32, jnp.bfloat16] DTYPES = { "L0": ALL_DTYPES, "L2": ALL_DTYPES, } -# With probs options ALL_WITH_PROBS = [True, False] WITH_PROBS = { "L0": [True], @@ -97,7 +88,9 @@ def reference_make_row_id_map( # Compute total tokens per expert and expert offsets tokens_per_expert = jnp.sum(routing_map, axis=0) - expert_offsets = jnp.concatenate([jnp.array([0]), jnp.cumsum(tokens_per_expert)[:-1]]) + expert_offsets = jnp.concatenate( + [jnp.array([0], dtype=jnp.int32), jnp.cumsum(tokens_per_expert)[:-1].astype(jnp.int32)] + ) # Compute destination rows for all (token, expert) pairs # dest_row[i, j] = expert_offsets[j] + cumsum_per_expert[i, j] - 1 if routed, else -1 @@ -115,7 +108,9 @@ def reference_make_row_id_map( # Gather the sorted destination rows and expert indices using advanced indexing # Create indices for gathering - token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts)) + token_idx = jnp.broadcast_to( + jnp.arange(num_tokens, dtype=jnp.int32)[:, None], (num_tokens, num_experts) + ) sorted_dest_rows = dest_rows_all[token_idx, sorted_expert_indices] # Build row_id_map: [dest_row_0, ..., dest_row_{E-1}, expert_idx_0, ..., expert_idx_{E-1}, n_routed] @@ -373,23 +368,27 @@ def reference_make_chunk_sort_map( Row ID map for chunk sorting of shape [num_tokens,]. """ # Compute source chunk boundaries (cumulative sum of original split_sizes) - src_cumsum = jnp.concatenate([jnp.array([0]), jnp.cumsum(split_sizes)]) + src_cumsum = jnp.concatenate( + [jnp.array([0], dtype=jnp.int32), jnp.cumsum(split_sizes).astype(jnp.int32)] + ) # Compute destination chunk boundaries based on sorted order sorted_sizes = split_sizes[sorted_indices] - dest_cumsum = jnp.concatenate([jnp.array([0]), jnp.cumsum(sorted_sizes)]) + dest_cumsum = jnp.concatenate( + [jnp.array([0], dtype=jnp.int32), jnp.cumsum(sorted_sizes).astype(jnp.int32)] + ) # For each source chunk, compute its destination offset # inverse_indices[i] = position of chunk i in sorted order - inverse_indices = jnp.argsort(sorted_indices) + inverse_indices = jnp.argsort(sorted_indices).astype(jnp.int32) dest_offsets = dest_cumsum[inverse_indices] # Create row_id_map: for each token position, compute its destination # First, figure out which chunk each position belongs to - position_indices = jnp.arange(num_tokens) + position_indices = jnp.arange(num_tokens, dtype=jnp.int32) # chunk_ids[i] = which chunk position i belongs to - chunk_ids = jnp.searchsorted(src_cumsum[1:], position_indices, side="right") + chunk_ids = jnp.searchsorted(src_cumsum[1:], position_indices, side="right").astype(jnp.int32) # within_chunk_offset[i] = position i's offset within its chunk within_chunk_offset = position_indices - src_cumsum[chunk_ids]