Skip to content

Conversation

@tdophung
Copy link
Collaborator

@tdophung tdophung commented Jan 7, 2026

Description

Continure to do what PR#2566 was doing.

Fix 50% comparison mismatch in sort_chunks_by_index
When using jnp.arange to initialize array, it could be ambiguous depending on the platform to use int64 or int32, this is to explicitly specify the dtype to eliminate ambiguity, which might have caused data to be misinterpreted and read as int32 when initialized as int64, causing a 50% data mismatch

Fixes # (issue)
50% comparison mismatch in sort_chunks_by_index

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Make array init and ops specify int32 dtype for test_sort_chunk_by_index

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung
Copy link
Collaborator Author

tdophung commented Jan 7, 2026

/te-ci L0 jax

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 7, 2026

Greptile Summary

This PR continues the work from #2566 to fix platform-dependent dtype ambiguity issues in JAX array operations. The changes add explicit int32 dtype casting to two additional locations (inverse_indices and chunk_ids) in the reference_make_chunk_sort_map function, and removes unnecessary section comment blocks throughout the file.

Key changes:

  • Added .astype(jnp.int32) to jnp.argsort(sorted_indices) on line 383
  • Added .astype(jnp.int32) to jnp.searchsorted(...) on line 391
  • Removed comment blocks that labeled test parameter sections

Why this matters:
When using jnp.arange and similar operations without explicit dtype, JAX may use int64 or int32 depending on the platform. This can cause data to be misinterpreted when arrays initialized as int64 are read as int32, leading to 50% data mismatch in comparisons as reported in issue #2566.

Confidence Score: 5/5

  • This PR is safe to merge - it makes targeted fixes to dtype specifications in test code
  • The changes are minimal, well-targeted, and follow the same pattern established in PR Fix 50% comparison mismatch in sort_chunks_by_index  #2566. They only affect test reference implementations, not production code. The explicit dtype casts eliminate platform-dependent behavior and improve test reliability.
  • No files require special attention

Important Files Changed

Filename Overview
tests/jax/test_permutation.py Added explicit int32 dtype casts to inverse_indices and chunk_ids, removed unnecessary comment blocks

Sequence Diagram

sequenceDiagram
    participant Test as Test Function
    participant RefFunc as reference_make_chunk_sort_map
    participant JAX as JAX Operations
    
    Test->>RefFunc: Call with split_sizes, sorted_indices
    RefFunc->>JAX: jnp.cumsum(split_sizes)
    JAX-->>RefFunc: Return cumsum array
    RefFunc->>RefFunc: Cast to int32 with .astype(jnp.int32)
    
    RefFunc->>JAX: jnp.cumsum(sorted_sizes)
    JAX-->>RefFunc: Return cumsum array
    RefFunc->>RefFunc: Cast to int32 with .astype(jnp.int32)
    
    RefFunc->>JAX: jnp.argsort(sorted_indices)
    JAX-->>RefFunc: Return argsort result
    RefFunc->>RefFunc: Cast to int32 with .astype(jnp.int32)
    Note over RefFunc: NEW: Ensures int32 for inverse_indices
    
    RefFunc->>JAX: jnp.arange(num_tokens, dtype=jnp.int32)
    JAX-->>RefFunc: Return int32 array
    Note over JAX,RefFunc: Explicit dtype prevents platform ambiguity
    
    RefFunc->>JAX: jnp.searchsorted(src_cumsum[1:], position_indices)
    JAX-->>RefFunc: Return searchsorted result
    RefFunc->>RefFunc: Cast to int32 with .astype(jnp.int32)
    Note over RefFunc: NEW: Ensures int32 for chunk_ids
    
    RefFunc->>RefFunc: Compute row_id_map using int32 arrays
    RefFunc-->>Test: Return int32 row_id_map
    Note over Test,RefFunc: All operations now use consistent int32 dtype
Loading

@tdophung tdophung changed the title [Draft] Fix bug test perm 2 Really fix 50% comparison mismatch in sort_chunks_by_index Jan 7, 2026
@tdophung tdophung changed the title Really fix 50% comparison mismatch in sort_chunks_by_index Fix 50% comparison mismatch in sort_chunks_by_index (Cont.) Jan 7, 2026
@tdophung tdophung merged commit 08dc786 into NVIDIA:main Jan 7, 2026
23 of 24 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants