Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/MaxText/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,12 +741,9 @@ def __call__(
model_mode,
)
if cfg.using_pipeline_parallelism:
if cfg.pipeline_fsdp_ag_once:
logical_partition_spec = self.pipeline_module.get_weight_sharding(
y, decoder_segment_ids, decoder_positions, deterministic, model_mode
)
else:
logical_partition_spec = None # This partition spec is only used for the fsdp_ag_once feature.
logical_partition_spec = self.pipeline_module.get_weight_sharding(
y, decoder_segment_ids, decoder_positions, deterministic, model_mode
)
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
dense_layer = RemattedBlockLayers[0]
Expand Down Expand Up @@ -946,6 +943,13 @@ def __call__(

else:
logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)
logits = sharding.maybe_shard_with_logical(
logits,
("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab"),
mesh=self.mesh,
shard_mode=self.config.shard_mode,
debug_sharding=self.config.debug_sharding,
)

# The API of the Decoder is now a tuple, providing both the main output
# and the raw hidden state needed for auxiliary tasks.
Expand Down
2 changes: 1 addition & 1 deletion src/MaxText/layers/deepseek_batchsplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ def gmm(
input_buffer_count,
combine_scopes,
):
if config.use_qwix_quantization:
if config.use_qwix_quantization or config.using_pipeline_parallelism:
output = megablox.gmm(
lhs=inputs,
rhs=kernel,
Expand Down
Loading
Loading