31 changes: 15 additions & 16 deletions tests/jax/test_permutation.py
@@ -19,11 +19,6 @@
 from utils import assert_allclose, pytest_parametrize_wrapper
 
 
-# =============================================================================
-# Test parameter definitions with L0 (fast) and L2 (comprehensive) levels
-# =============================================================================
-
-# All dispatch/combine test cases
 ALL_DISPATCH_COMBINE_CASES = [
     (128, 5, 128, 3),
     (1024, 8, 128, 8),
@@ -35,7 +30,6 @@
     "L2": ALL_DISPATCH_COMBINE_CASES,
 }
 
-# All sort chunks test cases
 ALL_SORT_CHUNKS_CASES = [
     (8, 4096, 1280),
     (64, 4096, 4096),
@@ -46,7 +40,6 @@
     "L2": ALL_SORT_CHUNKS_CASES,
 }
 
-# All dispatch/combine with padding test cases
 ALL_DISPATCH_COMBINE_PADDING_CASES = [
     (128, 5, 128, 3, 8),
     (1024, 8, 128, 8, 16),
@@ -58,14 +51,12 @@
     "L2": ALL_DISPATCH_COMBINE_PADDING_CASES,
 }
 
-# Dtypes for testing
 ALL_DTYPES = [jnp.float32, jnp.bfloat16]
 DTYPES = {
     "L0": ALL_DTYPES,
     "L2": ALL_DTYPES,
 }
 
-# With probs options
 ALL_WITH_PROBS = [True, False]
 WITH_PROBS = {
     "L0": [True],
@@ -97,7 +88,9 @@ def reference_make_row_id_map(
 
     # Compute total tokens per expert and expert offsets
     tokens_per_expert = jnp.sum(routing_map, axis=0)
-    expert_offsets = jnp.concatenate([jnp.array([0]), jnp.cumsum(tokens_per_expert)[:-1]])
+    expert_offsets = jnp.concatenate(
+        [jnp.array([0], dtype=jnp.int32), jnp.cumsum(tokens_per_expert)[:-1].astype(jnp.int32)]
+    )
 
     # Compute destination rows for all (token, expert) pairs
     # dest_row[i, j] = expert_offsets[j] + cumsum_per_expert[i, j] - 1 if routed, else -1
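Note on the hunk above: the explicit dtype arguments pin the offset math to int32. A minimal standalone sketch of the drift this appears to guard against (my reading of the intent, not taken from the PR; values are hypothetical):

import jax
import jax.numpy as jnp

# Assumption for illustration: the x64 flag is the mode in which dtypes drift.
jax.config.update("jax_enable_x64", True)

tokens_per_expert = jnp.array([3, 1, 2], dtype=jnp.int32)

# Without explicit dtypes: jnp.array([0]) defaults to int64 under x64,
# so the concatenated result silently promotes to int64.
drifting = jnp.concatenate([jnp.array([0]), jnp.cumsum(tokens_per_expert)[:-1]])
print(drifting.dtype)  # int64

# With the casts from the diff: int32 regardless of the x64 flag.
pinned = jnp.concatenate(
    [jnp.array([0], dtype=jnp.int32), jnp.cumsum(tokens_per_expert)[:-1].astype(jnp.int32)]
)
print(pinned.dtype)  # int32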
@@ -115,7 +108,9 @@
 
     # Gather the sorted destination rows and expert indices using advanced indexing
     # Create indices for gathering
-    token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts))
+    token_idx = jnp.broadcast_to(
+        jnp.arange(num_tokens, dtype=jnp.int32)[:, None], (num_tokens, num_experts)
+    )
     sorted_dest_rows = dest_rows_all[token_idx, sorted_expert_indices]
 
     # Build row_id_map: [dest_row_0, ..., dest_row_{E-1}, expert_idx_0, ..., expert_idx_{E-1}, n_routed]
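For context on the gather in this hunk: advanced indexing with a broadcast row index reorders each token's per-expert entries by that token's sorted expert order. A small self-contained illustration, with made-up shapes and values:

import jax.numpy as jnp

num_tokens, num_experts = 3, 2
# Destination row per (token, expert); -1 marks "not routed".
dest_rows_all = jnp.array([[0, 3], [-1, 4], [1, -1]], dtype=jnp.int32)
# Per-token expert order after sorting.
sorted_expert_indices = jnp.array([[1, 0], [1, 0], [0, 1]], dtype=jnp.int32)

token_idx = jnp.broadcast_to(
    jnp.arange(num_tokens, dtype=jnp.int32)[:, None], (num_tokens, num_experts)
)
# Row i of the result is dest_rows_all[i] reordered by sorted_expert_indices[i].
sorted_dest_rows = dest_rows_all[token_idx, sorted_expert_indices]
print(sorted_dest_rows)  # [[3 0] [4 -1] [1 -1]]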
@@ -373,23 +368,27 @@ def reference_make_chunk_sort_map(
         Row ID map for chunk sorting of shape [num_tokens,].
     """
     # Compute source chunk boundaries (cumulative sum of original split_sizes)
-    src_cumsum = jnp.concatenate([jnp.array([0]), jnp.cumsum(split_sizes)])
+    src_cumsum = jnp.concatenate(
+        [jnp.array([0], dtype=jnp.int32), jnp.cumsum(split_sizes).astype(jnp.int32)]
+    )
 
     # Compute destination chunk boundaries based on sorted order
     sorted_sizes = split_sizes[sorted_indices]
-    dest_cumsum = jnp.concatenate([jnp.array([0]), jnp.cumsum(sorted_sizes)])
+    dest_cumsum = jnp.concatenate(
+        [jnp.array([0], dtype=jnp.int32), jnp.cumsum(sorted_sizes).astype(jnp.int32)]
+    )
 
     # For each source chunk, compute its destination offset
     # inverse_indices[i] = position of chunk i in sorted order
-    inverse_indices = jnp.argsort(sorted_indices)
+    inverse_indices = jnp.argsort(sorted_indices).astype(jnp.int32)
     dest_offsets = dest_cumsum[inverse_indices]
 
     # Create row_id_map: for each token position, compute its destination
     # First, figure out which chunk each position belongs to
-    position_indices = jnp.arange(num_tokens)
+    position_indices = jnp.arange(num_tokens, dtype=jnp.int32)
 
     # chunk_ids[i] = which chunk position i belongs to
-    chunk_ids = jnp.searchsorted(src_cumsum[1:], position_indices, side="right")
+    chunk_ids = jnp.searchsorted(src_cumsum[1:], position_indices, side="right").astype(jnp.int32)
 
     # within_chunk_offset[i] = position i's offset within its chunk
     within_chunk_offset = position_indices - src_cumsum[chunk_ids]
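To make the reference logic above concrete, here is a minimal sketch (illustrative values only) of the two idioms it combines: argsort of a permutation yields the inverse permutation, and searchsorted against the upper cumulative boundaries with side="right" maps each position to its containing chunk.

import jax.numpy as jnp

split_sizes = jnp.array([2, 3, 1], dtype=jnp.int32)      # chunk lengths
sorted_indices = jnp.array([2, 0, 1], dtype=jnp.int32)   # chunk order after sorting

# argsort of a permutation is its inverse: chunk i lands at
# sorted position inverse_indices[i].
inverse_indices = jnp.argsort(sorted_indices)
print(inverse_indices)  # [1 2 0]

# Boundaries: chunk 0 covers [0, 2), chunk 1 covers [2, 5), chunk 2 covers [5, 6).
src_cumsum = jnp.concatenate(
    [jnp.array([0], dtype=jnp.int32), jnp.cumsum(split_sizes).astype(jnp.int32)]
)
positions = jnp.arange(6, dtype=jnp.int32)
chunk_ids = jnp.searchsorted(src_cumsum[1:], positions, side="right")
print(chunk_ids)  # [0 0 1 1 1 2]

# Offset of each position within its own chunk.
print(positions - src_cumsum[chunk_ids])  # [0 1 0 1 2 0]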