Skip to content

[JAX] Fix FSDP when FSDP+EP is active#2649

Open
jberchtold-nvidia wants to merge 4 commits intoNVIDIA:mainfrom
jberchtold-nvidia:jberchtold/fix-fsdp-when-ep-is-active
Open

[JAX] Fix FSDP when FSDP+EP is active#2649
jberchtold-nvidia wants to merge 4 commits intoNVIDIA:mainfrom
jberchtold-nvidia:jberchtold/fix-fsdp-when-ep-is-active

Conversation

@jberchtold-nvidia
Copy link
Collaborator

Description

In some models when FSDP and EP are both active, the non-MoE blocks use (('fsdp', 'expert'), None, None, None) with the EP GPU domain acting as FSDP. Our TE/JAX GEMM did not handle this correctly as it assumed fsdp would be present in an axis alone, not as part of a tuple. This resulting in unnecessary AllGather's of the inputs blocking the critical path.

This PR fixes the check and if the TE GEMM sees an inspect sharding with fsdp as part a tuple, it performs FSDP on the GPU domain of all axes specified in the tuple.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Update the check in TE/JAX GEMM to handle ('fsdp', 'expert') along the same axis

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 4, 2026

Greptile Overview

Greptile Summary

  • Adjusts TE/JAX GEMM sharding inference to treat PartitionSpec entries where fsdp_resource appears inside a tuple (e.g. ('fsdp','expert')) as requiring RHS non-contracting dims to be gathered (set to None).
  • Change is localized to GemmPrimitive._parse_operand_output_specs, which feeds partition()/infer_sharding_from_operands() and ultimately drives how JAX places collectives around the custom GEMM.
  • Intended effect is to prevent incorrect/extra input all-gathers on the critical path when FSDP and Expert Parallelism share an axis tuple.

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk.
  • Single, targeted change to a local sharding-spec rewrite; logic is straightforward and consistent with the existing intent of unsharding RHS non-contracting dims along FSDP. No other call sites or invariants appear to be affected.
  • transformer_engine/jax/cpp_extensions/gemm.py: verify behavior with tuple PartitionSpec entries in an integration test if available.

Important Files Changed

Filename Overview
transformer_engine/jax/cpp_extensions/gemm.py Updates RHS non-contracting sharding check to treat tuple specs containing fsdp_resource as needing gather (set spec=None), addressing FSDP+EP tuple PartitionSpec cases.

Sequence Diagram

sequenceDiagram
    participant UserCode as Model/TE JAX caller
    participant JAX as JAX sharding propagation
    participant Gemm as GemmPrimitive._parse_operand_output_specs

    UserCode->>JAX: Invoke TE gemm with sharded operands
    JAX->>Gemm: Provide arg_infos (PartitionSpec per operand dim)
    Gemm->>Gemm: Split specs into contracting/non-contracting
    Gemm->>Gemm: If reduce_spec is None
    Gemm->>Gemm: For each rhs non-contracting dim
    alt spec == fsdp_resource
        Gemm->>Gemm: Set rhs spec to None (gather along FSDP)
    else spec is tuple and contains fsdp_resource
        Gemm->>Gemm: Set rhs spec to None (gather along FSDP tuple axis)
    else other spec
        Gemm->>Gemm: Keep rhs spec unchanged
    end
    Gemm-->>JAX: Return inferred operand/output PartitionSpecs
    JAX-->>UserCode: Compile with corrected sharding (avoid unnecessary all-gathers)
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/fix-fsdp-when-ep-is-active branch from 4ee4d73 to 08fda9c Compare February 4, 2026 17:10
@jberchtold-nvidia
Copy link
Collaborator Author

/te-ci L1 jax

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

@jberchtold-nvidia
Copy link
Collaborator Author

/te-ci L1 jax

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant