67 changes: 55 additions & 12 deletions src/MaxText/sharding.py
@@ -29,9 +29,11 @@
from maxtext.utils import max_logging
from maxtext.utils import max_utils

import inspect # for debugging only
from pathlib import Path

_LOGGED_ACTIVATION_SHARDINGS = set()
_LOGGED_LOGICAL_AXES = set()
_ACTIVATION_SHARDINGS_DUMP = []


def get_input_data_sharding(config, mesh):
@@ -45,51 +47,92 @@ def get_input_data_sharding(config, mesh):
return data_sharding


def maybe_shard_with_name(inputs, named_sharding, shard_mode, debug_sharding=False, extra_stack_level=0):
def _get_sharding_desc(inputs, extra_stack_level):
"""Get the inputs sharding description using inspect module"""
  frame = inspect.currentframe()
  # Walk back 1 + extra_stack_level frames: one for this helper, plus any wrapper frames.
  for _ in range(1 + extra_stack_level):
    if frame is not None:
      frame = frame.f_back
  if frame is None:
    return "Unknown"
  # Find the caller's local variable(s) that refer to this exact object.
  x = [var_name for var_name, var_val in frame.f_locals.items() if var_val is inputs]
  if x:
    caller_path_full = inspect.stack()[1 + extra_stack_level].filename
    # Use pathlib.Path to extract just the filename from the full path; strip the ".py" suffix.
    caller_filename = Path(caller_path_full).name
    return f"{caller_filename[:-3]}/{x[0]}"
  return "Unknown"


def maybe_shard_with_name(
inputs, named_sharding, shard_mode, debug_sharding=False, extra_stack_level=0, sharding_desc="", logical_axes=None
):
"""
In auto shardmode, this function hints inputs follow given named_sharding.
In explicit shardmode, this function enforces inputs following named_sharding.
sharding_desc is description of inputs of upper layer(s) of caller (with the form of <filename>/<variable>).
It is used as key in log/dump files when debug_sharding==true
"""
if inputs is None:
return None
if (
debug_sharding and isinstance(inputs, Tracer) and isinstance(named_sharding, NamedSharding)
): # only print pspec for JitTracer
if not sharding_desc:
sharding_desc = _get_sharding_desc(inputs, extra_stack_level + 1)

if not logical_axes:
logical_axes = "Unknown"
elif isinstance(logical_axes, list):
logical_axes = tuple(logical_axes)

pspec = remove_size_one_mesh_axis(getattr(named_sharding, "spec"), getattr(named_sharding, "mesh"))
log_key = (str(jax.typeof(inputs)), tuple(pspec), extra_stack_level)
log_key = (sharding_desc, str(jax.typeof(inputs)), tuple(pspec), extra_stack_level)
if log_key not in _LOGGED_ACTIVATION_SHARDINGS:
max_logging.info(f"Physical: {log_key[0]:.<80} {log_key[1]}.", stacklevel=3 + extra_stack_level)
max_logging.info(f"{sharding_desc} Logical: {log_key[1]:.<60} {logical_axes}.", stacklevel=3 + extra_stack_level)
max_logging.info(f"{sharding_desc} Physical: {log_key[1]:.<60} {log_key[2]}.", stacklevel=3 + extra_stack_level)
_LOGGED_ACTIVATION_SHARDINGS.add(log_key)

_ACTIVATION_SHARDINGS_DUMP.append(
{
f"{sharding_desc}: {log_key[1]}": {
"logic_axes": f"{logical_axes}",
"PartitionSpec": f"P{log_key[2]}",
}
}
)
if shard_mode == ShardMode.EXPLICIT:
return reshard(inputs, named_sharding)
else:
return jax.lax.with_sharding_constraint(inputs, named_sharding)


def maybe_shard_with_logical(
inputs, logical_axes, mesh, shard_mode, rules=None, debug_sharding=False, extra_stack_level=0
inputs, logical_axes, mesh, shard_mode, rules=None, debug_sharding=False, extra_stack_level=0, sharding_desc=""
):
"""
A wrapper of maybe_shard_with_name when logical axes are inputs
sharding_desc is description of inputs of upper layer(s) of caller (with the form of <filename>/<variable>).
It is used as key in log/dump files when debug_sharding==true
"""
if inputs is None:
return None

named_sharding = create_sharding(mesh, logical_axes, rules=rules)

if debug_sharding and isinstance(inputs, Tracer):
log_key = (str(jax.typeof(inputs)), tuple(logical_axes), extra_stack_level)
if debug_sharding and not sharding_desc:
sharding_desc = _get_sharding_desc(inputs, extra_stack_level + 1)

if log_key not in _LOGGED_LOGICAL_AXES:
max_logging.info(f"Logical: {log_key[0]:.<60} {log_key[1]}", stacklevel=3 + extra_stack_level)
_LOGGED_LOGICAL_AXES.add(log_key)
named_sharding = create_sharding(mesh, logical_axes, rules=rules)

return maybe_shard_with_name(
inputs,
named_sharding,
shard_mode,
debug_sharding=debug_sharding,
extra_stack_level=extra_stack_level + 1,
sharding_desc=sharding_desc,
logical_axes=logical_axes,
)


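Note on the new debug path (illustration, not part of the diff): _get_sharding_desc recovers a human-readable name for an activation by walking back up the call stack and matching the array against the caller's local variables by identity. Below is a minimal standalone sketch of that technique; describe_value and caller are illustrative names, not MaxText code, and the sketch reads the filename from the frame's code object rather than inspect.stack(), which should be equivalent for this purpose.

import inspect
from pathlib import Path


def describe_value(value, extra_stack_level=0):
  """Return "<caller file stem>/<variable name>" for value, or "Unknown"."""
  frame = inspect.currentframe()
  # Walk back 1 + extra_stack_level frames: one for this function, plus any wrappers.
  for _ in range(1 + extra_stack_level):
    if frame is not None:
      frame = frame.f_back
  if frame is None:
    return "Unknown"
  # Identity comparison: find which of the caller's locals is this exact object.
  names = [name for name, val in frame.f_locals.items() if val is value]
  if names:
    return f"{Path(frame.f_code.co_filename).stem}/{names[0]}"
  return "Unknown"


def caller():
  activations = [1.0, 2.0, 3.0]
  return describe_value(activations)  # "<this file's stem>/activations"


print(caller())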
5 changes: 4 additions & 1 deletion src/MaxText/vocabulary_tiling.py
@@ -89,7 +89,10 @@ def vocab_tiling_linen_loss(
)

_maybe_shard_with_name = functools.partial(
maybe_shard_with_name, shard_mode=config.shard_mode, debug_sharding=config.debug_sharding
maybe_shard_with_name,
shard_mode=config.shard_mode,
debug_sharding=config.debug_sharding,
extra_stack_level=1,
)

def _reshape(inputs, out_shape, out_sharding):
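The extra_stack_level=1 arguments threaded through the call sites in this PR appear to account for the thin wrappers sitting between user code and the shared sharding helper: each wrapper adds a frame, so the name lookup (and the logging stacklevel) needs one more hop to land on the frame whose locals actually name the activation. A self-contained sketch of that frame counting, using hypothetical names (_find_name, shard_wrapper, train_step), not MaxText APIs:

import inspect


def _find_name(value, extra_stack_level=0):
  # Frame 0 is this function; walk back 1 + extra_stack_level frames.
  frame = inspect.currentframe()
  for _ in range(1 + extra_stack_level):
    if frame is not None:
      frame = frame.f_back
  if frame is None:
    return "Unknown"
  names = [n for n, v in frame.f_locals.items() if v is value]
  return names[0] if names else "Unknown"


def shard_wrapper(value):
  # This wrapper adds one frame, so it forwards extra_stack_level=1 to make
  # the lookup inspect its caller's locals rather than its own.
  return _find_name(value, extra_stack_level=1)


def train_step():
  hidden_states = [0.0]
  direct = _find_name(hidden_states)      # "hidden_states"
  wrapped = shard_wrapper(hidden_states)  # "hidden_states" (not "value")
  return direct, wrapped


print(train_step())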
1 change: 1 addition & 0 deletions src/maxtext/layers/moe.py
@@ -459,6 +459,7 @@ def _maybe_shard_with_logical(self, inputs, logical_name):
mesh=self.mesh,
shard_mode=self.config.shard_mode,
debug_sharding=self.config.debug_sharding,
extra_stack_level=1,
)

def _logical_to_mesh_axes(self, logical_name):
1 change: 1 addition & 0 deletions src/maxtext/layers/pipeline.py
@@ -133,6 +133,7 @@ def _maybe_shard_with_logical(self, inputs, logical_axes):
mesh=self.mesh,
rules=self.config.logical_axis_rules,
debug_sharding=self.config.debug_sharding,
extra_stack_level=1,
)

def _maybe_shard_with_name(self, inputs, sharding_name):
15 changes: 9 additions & 6 deletions src/maxtext/models/deepseek.py
@@ -138,16 +138,20 @@ def with_logical_constraint(self, x):
mesh=self.mesh,
shard_mode=self.config.shard_mode,
debug_sharding=self.config.debug_sharding,
extra_stack_level=1,
)

def dropout_op(self, x, deterministic):
return self.with_logical_constraint(self.dropout(x, deterministic=deterministic))
dropout = self.dropout(x, deterministic=deterministic)
return self.with_logical_constraint(dropout)

def pre_attention_norm_op(self, x):
return self.with_logical_constraint(self.pre_self_attention_layer_norm(x))
pre_attention_norm = self.pre_self_attention_layer_norm(x)
return self.with_logical_constraint(pre_attention_norm)

def post_attention_norm_op(self, x):
return self.with_logical_constraint(self.post_self_attention_layer_norm(x))
post_attention_norm = self.post_self_attention_layer_norm(x)
return self.with_logical_constraint(post_attention_norm)

def attention_op(
self,
@@ -280,9 +284,8 @@ def __init__(
)

def mlp_op(self, x, deterministic):
return self.with_logical_constraint(
self.mlp(x, deterministic, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding)
)
mlp = self.mlp(x, deterministic, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding)
return self.with_logical_constraint(mlp)

def __call__(
self,
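The deepseek.py changes that split chained calls into a named local before applying with_logical_constraint look cosmetic, but they plausibly serve the new debug path: the description helper matches activations against the caller's locals by identity, so an intermediate value that is never bound to a name can only be reported as "Unknown". A small illustration under that assumption, using hypothetical names (name_of, layer, chained, named):

import inspect


def name_of(value):
  # Look up value by identity in the caller's local variables.
  caller = inspect.currentframe().f_back
  names = [n for n, v in caller.f_locals.items() if v is value]
  return names[0] if names else "Unknown"


def layer(x):
  return [v * 2.0 for v in x]


def chained(x):
  # The intermediate result is never bound to a local, so it has no name here.
  return name_of(layer(x))


def named(x):
  # Binding the result first gives the lookup something to find.
  activations = layer(x)
  return name_of(activations)


print(chained([1.0]))  # "Unknown"
print(named([1.0]))    # "activations"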
1 change: 1 addition & 0 deletions src/maxtext/models/llama2.py
@@ -133,6 +133,7 @@ def __init__(
mesh=self.mesh,
shard_mode=config.shard_mode,
debug_sharding=config.debug_sharding,
extra_stack_level=1,
)

def __call__(