Skip to content

Commit 26c82db

Browse files
KshitijLakhani, pre-commit-ci[bot], Kshitij Janardan Lakhani, ksivaman
authored
[JAX] Fix incorrect calculation of segment pos from segment ids in user-facing API (#2523)
* Fix incorrect calculation of segment pos from segment ids for thd cases and load balanced cases in from_segment_ids_and_pos. Enforce passing of segment_pos for THD cases and load balanced cases Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Correct the assert condition Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Modify fused attn tests to pass new args to from_segment_ids_and_pos() Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Calculate seg ids before pos Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Change the signature for from_segment_ids_and_pos() 2. Add support for THD in from_segment_ids_and_pos() 3. Assert if load balanced segment_ids is passed to generate a segment_pos Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Pass keyword-only args by name Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * nit: Fix typo to use seg_ids instead of segment_ids Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * nit: Fix comments Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * Modify the function call to differentiate between load balancing and actually reordered segment_ids and segment_pos Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix the is_segment_ids_reordered to be set only when CP and load balancing Signed-off-by: 
Kshitij Lakhani <klakhani@nvidia.com> * Fix comments for from_segment_ids_and_pos() Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Code clean up for more information, see https://pre-commit.ci Fix lint errors Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
1 parent 5ba01fa commit 26c82db

File tree

2 files changed

+90
-11
lines changed

2 files changed

+90
-11
lines changed

tests/jax/test_fused_attn.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -668,14 +668,24 @@ def generate_random_segment_ids(
668668
(self.offsets_q, self.offsets_kv),
669669
)
670670
case SeqDescFormat.SegmentIDs:
671+
# Exercise the path to generate the segment_pos in from_segment_ids_and_pos()
672+
# if no CP and load balancing, else explicitly pass the segment_pos
671673
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
672674
(
673675
self.cp_reorder_fn(self.segment_ids_q),
674676
self.cp_reorder_fn(self.segment_ids_kv),
675677
),
676678
(
677-
self.cp_reorder_fn(self.segment_pos_q),
678-
self.cp_reorder_fn(self.segment_pos_kv),
679+
(
680+
self.cp_reorder_fn(self.segment_pos_q),
681+
self.cp_reorder_fn(self.segment_pos_kv),
682+
)
683+
if self.cp_size > 1 and self.cp_load_balanced
684+
else None
685+
),
686+
is_thd=self.qkv_layout.is_thd(),
687+
is_segment_ids_reordered=(
688+
True if self.cp_size > 1 and self.cp_load_balanced else False
679689
),
680690
)
681691
case _:
@@ -704,6 +714,8 @@ def generate_random_segment_ids(
704714
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
705715
(self.segment_ids_q, self.segment_ids_kv),
706716
None,
717+
is_thd=self.qkv_layout.is_thd(),
718+
is_segment_ids_reordered=False,
707719
)
708720
case _:
709721
raise ValueError(f"Unknown {self.seq_desc_format=}")

transformer_engine/jax/attention.py

Lines changed: 76 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -658,7 +658,7 @@ class SequenceDescriptor:
658658
- SequenceDescriptor.from_seqlens_and_offsets
659659
For THD (packed) cases, where each batch may have not only 1 sequence.
660660
- SequenceDescriptor.from_segment_ids_and_pos
661-
Experimental feature for THD (packed) cases with context parallelism.
661+
Experimental feature for BSHD (with and without reordering) and THD (packed) cases without reordering
662662
"""
663663

664664
seqlens: Optional[Tuple[jnp.ndarray, jnp.ndarray]]
@@ -796,9 +796,14 @@ def from_segment_ids_and_pos(
796796
cls,
797797
segment_ids: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
798798
segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None,
799+
*,
800+
is_thd: bool,
801+
is_segment_ids_reordered: bool,
799802
) -> SequenceDescriptor:
800803
"""
801-
Experimental factory method for inputs with segment IDs and optional positions. (THD)
804+
Experimental factory method for inputs with segment IDs and optional positions.
805+
segment_pos = None to be used only for: BSHD with or without load balancing and,
806+
THD without load balancing
802807
Args:
803808
segment_ids(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_ids, kv_segment_ids):
804809
- q_segment_ids (jnp.ndarray):
@@ -812,22 +817,84 @@ def from_segment_ids_and_pos(
812817
The position inside each segment for query, with shape [batch, max_seqlen].
813818
- kv_segment_pos (jnp.ndarray):
814819
The position inside each segment for key, value, with shape [batch, max_seqlen].
820+
is_thd(bool): If True, QKVLayout is of type THD, else it is BSHD
821+
is_segment_ids_reordered(bool): If True, the segment ids have been reordered for load balancing.
822+
Only THD with load balancing is expected to have this flag set to True
815823
Return:
816824
A SequenceDescriptor with segment_ids/segment_pos initialized.
817825
"""
818826
q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids)
819827

820-
if segment_pos is not None:
821-
segment_pos = cls._expand_to_pair(segment_pos)
822-
else:
823-
824-
def generate_default_pos(segment_ids):
825-
seqlen = segment_ids.shape[-1]
826-
return jnp.broadcast_to(jnp.arange(seqlen), segment_ids.shape)
828+
# Using defaults : segment pos has to be generated.
829+
if segment_pos is None:
830+
# THD + load balanced segment_ids are not supported in this function
831+
# BSHD + load balanced segment_ids are incorrect as BSHD handles reordering within the primitive itself
832+
if is_segment_ids_reordered:
833+
assert not is_thd, (
834+
f"{segment_pos=} default arg is not supported for load balanced reordered"
835+
" (Striped) THD inputs. Please pass the load balanced reordered segment_pos"
836+
" and segment_ids explicitly to {from_segment_ids_and_pos.__qualname__}"
837+
" using convenience function reorder_causal_load_balancing()"
838+
)
839+
assert is_thd, (
840+
f"{segment_pos=} default arg is not supported for load balanced reordered (Dual"
841+
" Chunk) BSHD inputs. BSHD segment_pos and segment_ids do not need to be load"
842+
" balanced reordered. The reordering for these is performed within the"
843+
" primitive"
844+
)
845+
846+
# Generate the default pos for THD and BSHD non-reordered segment_ids
847+
def generate_default_pos(seg_ids):
848+
if is_thd:
849+
batch_size, seq_size = seg_ids.shape
850+
# Assume that the first token belongs to a segment and is not a padded token
851+
first_is_segment = jnp.full((batch_size, 1), True, dtype=bool)
852+
# Get segment start positions
853+
segment_start = jnp.concatenate(
854+
[
855+
first_is_segment,
856+
(seg_ids[..., 1:] != seg_ids[..., :-1]) & (seg_ids[..., 1:] != 0),
857+
],
858+
axis=-1,
859+
)
860+
# Get offset for location where new segment starts
861+
segment_start_idx = jax.vmap(lambda row: jnp.arange(row.size) * row)(
862+
segment_start
863+
)
864+
segment_start_offsets = jax.vmap(jnp.maximum.accumulate)(segment_start_idx)
865+
866+
# Get the last non-zero index - after this everything is padding
867+
# (B,)
868+
last_nonzero_idx = jax.vmap(
869+
lambda segids_row: jnp.max(
870+
jnp.where(segids_row != 0, jnp.arange(seq_size), -1)
871+
)
872+
)(seg_ids)
873+
seg_pos_no_thd = jnp.arange(seq_size)
874+
# Get a mask which can be used to zero out all the padding at the end (after the non-zero index)
875+
mask = seg_pos_no_thd <= last_nonzero_idx[:, None]
876+
877+
# Get the unmasked seg_pos for the THD sequence
878+
seg_pos = (
879+
jnp.broadcast_to(jnp.arange(seq_size), seg_ids.shape)
880+
- segment_start_offsets
881+
)
882+
883+
# Use the mask to zero out the padding at the end (after the non-zero index)
884+
segment_pos = jax.vmap(
885+
lambda pos_row, mask_row: jnp.where(mask_row, pos_row, 0)
886+
)(seg_pos, mask)
887+
return segment_pos
888+
889+
seqlen = seg_ids.shape[-1]
890+
return jnp.broadcast_to(jnp.arange(seqlen), seg_ids.shape)
827891

828892
q_seg_pos = generate_default_pos(q_seg_ids)
829893
kv_seg_pos = generate_default_pos(kv_seg_ids)
830894
segment_pos = (q_seg_pos, kv_seg_pos)
895+
# Explicitly passed segment_pos
896+
else:
897+
segment_pos = cls._expand_to_pair(segment_pos)
831898

832899
return cls(
833900
segment_ids=(q_seg_ids, kv_seg_ids),

0 commit comments

Comments
 (0)