Skip to content

Commit 424bcb1

Browse files
committed
logic
Signed-off-by: Will Johnson <mwjohnson728@gmail.com>
1 parent 61d0dd7 commit 424bcb1

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,14 +405,27 @@ def _infer_prefixes_and_module_names(
405405
# 'model-00001-of-00002.safetensors')],
406406
# 'router.weight': [('model.layers.0.block_sparse_moe.router.layer.weight',
407407
# 'model-00001-of-00002.safetensors')]})
408+
ip_op_layers = False
409+
router_layer = False
410+
lora_utils = None
411+
if lora:
412+
for namex, _ in sd.items():
413+
if "w1" in namex:
414+
ip_op_layers = True
415+
break
416+
for namex, _ in sd.items():
417+
if "router" in namex:
418+
router_layer = True
419+
break
420+
lora_utils = [router_layer, ip_op_layers]
408421

409422
checkpoint_metadata = get_checkpoint_meta_from_sharded_safetensor(
410423
weight_map,
411424
prefix,
412425
module_name,
413426
router_name,
414427
expert_name,
415-
lora_utils=lora,
428+
lora_utils=lora_utils,
416429
)
417430

418431
model2scatter = defaultdict(dict)

plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def get_checkpoint_meta_from_sharded_safetensor(
8989
expert_name: str = "experts", # e.g., named "experts" within block_sparse_moe
9090
expert_map: Dict = None, # map -> [w1,w2,w3]
9191
lora_start: bool = False, # if lora is detected in prepare_scattermoe.py
92-
lora_utils: bool = False, # if lora is detected in checkpoint_utils.py
92+
lora_utils: List = None, # if lora is detected in checkpoint_utils.py
9393
target_modules: Dict = None, # target modules from prepare_scattermoe.py
9494
) -> Dict[str, List[Tuple]]:
9595
"""
@@ -171,7 +171,7 @@ def _insert(L: List, i: int, v):
171171
f"'{router_name}' or expert_name '{expert_name}'"
172172
)
173173
if m.group(1) == router_name:
174-
if lora_utils:
174+
if lora_utils[0]:
175175
_map[KEY_SCATTERMOE_LORA_A_ROUTER].append((k, stfile))
176176
_map[KEY_SCATTERMOE_LORA_B_ROUTER].append((k, stfile))
177177
else:
@@ -184,7 +184,7 @@ def _insert(L: List, i: int, v):
184184
index = m.group(2)
185185
index = 0 if index is None else int(index)
186186
mod = None
187-
if not lora_utils:
187+
if not lora_utils[1]:
188188
for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))):
189189
_insert(_map[f"{mod}.weight"], index, (k, stfile))
190190
else:

0 commit comments

Comments (0)