@@ -104,6 +104,16 @@ def _create_weight(self):
         )
         return

+    def _get_param_slicer(self, sub_child_index: int):
+        """
+        Some subclasses need a different slicer per component, e.g. the fused
+        qkv case. This hook lets a subclass override which slicer each part
+        uses: QKVROWNMMWeight, for instance, slices q and kv differently.
+        In most cases the same slicer is returned for every part.
+        sub_child_index: index of the sub-weight being loaded, for overrides.
+        """
+        return self.param_slicer
+
     # Execution order
     def _load_weight(
         self, param_name: Union[str, List[str]], weights: Dict[str, torch.Tensor], sub_child_index: int
@@ -113,15 +123,17 @@ def _load_weight
         if quanted_param_name in weights:
             param_name = quanted_param_name
         if param_name in weights:
-            weight = self.param_slicer._slice_weight(weights[param_name])
+            slicer = self._get_param_slicer(sub_child_index)
+            weight = slicer._slice_weight(weights[param_name])
             self.quant_method.load_weight(weight, self.mm_param_list[sub_child_index])
         return

     def _load_bias(
         self, param_name: Union[str, List[str]], weights: Dict[str, torch.Tensor], sub_child_index: int
     ) -> None:
         if param_name in weights:
-            bias = self.param_slicer._slice_bias(weights[param_name])
+            slicer = self._get_param_slicer(sub_child_index)
+            bias = slicer._slice_bias(weights[param_name])
             self.bias_list[sub_child_index].copy_(bias)
             self.bias_list[sub_child_index].load_ok = True
         return
@@ -130,15 +142,17 @@ def _load_weight_scale(
         self, param_name: Union[str, List[str]], weights: Dict[str, torch.Tensor], sub_child_index: int
     ) -> None:
         if param_name in weights:
-            weight_scale = self.param_slicer._slice_weight_scale(weights[param_name])
+            slicer = self._get_param_slicer(sub_child_index)
+            weight_scale = slicer._slice_weight_scale(weights[param_name])
             self.quant_method.load_weight_scale(weight_scale, self.mm_param_list[sub_child_index])
         return

     def _load_weight_zero_point(
         self, param_name: Union[str, List[str]], weights: Dict[str, torch.Tensor], sub_child_index: int
     ) -> None:
         if param_name in weights:
-            weight_zero_point = self.param_slicer._slice_weight_zero_point(weights[param_name])
+            slicer = self._get_param_slicer(sub_child_index)
+            weight_zero_point = slicer._slice_weight_zero_point(weights[param_name])
             self.quant_method.load_weight_zero_point(weight_zero_point, self.mm_param_list[sub_child_index])
         return

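The hunks above are a template-method hook: every loader now asks _get_param_slicer(sub_child_index) instead of reading self.param_slicer directly, so a subclass can pick a different slicer per sub-weight. A minimal standalone sketch of the pattern, with hypothetical names (ToySlicer, WeightBase, QKVWeight, and the load/slice methods are illustrative, not this repo's API):

class ToySlicer:
    """Toy stand-in for a TP slicer: keeps a fixed prefix of a sequence."""

    def __init__(self, keep: int):
        self.keep = keep

    def slice(self, t):
        return t[: self.keep]


class WeightBase:
    def __init__(self, param_slicer):
        self.param_slicer = param_slicer

    def _get_param_slicer(self, sub_child_index: int):
        # Hook with a do-nothing default: one slicer for every sub-weight.
        return self.param_slicer

    def load(self, tensors, sub_child_index: int):
        # Loaders always go through the hook, never through self.param_slicer.
        return self._get_param_slicer(sub_child_index).slice(tensors[sub_child_index])


class QKVWeight(WeightBase):
    def __init__(self, q_slicer, kv_slicer):
        super().__init__(q_slicer)
        self.kv_param_slicer = kv_slicer

    def _get_param_slicer(self, sub_child_index: int):
        # 0 -> q, 1/2 -> k/v, mirroring the override in the second file below.
        return self.param_slicer if sub_child_index == 0 else self.kv_param_slicer


w = QKVWeight(ToySlicer(4), ToySlicer(2))
print(w.load([list(range(8))] * 3, 0))  # q slice  -> [0, 1, 2, 3]
print(w.load([list(range(8))] * 3, 1))  # k slice  -> [0, 1]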
@@ -108,15 +108,18 @@ def __init__(
     ) -> None:
         self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp()
         self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size()
-        self.repeat_times = 1
+        self.q_repeat_times = 1
+        self.kv_repeat_times = 1
         assert q_head_num % self.tp_world_size_ == 0, (
             f"q_head_num must be divisible by tp_world_size_, " f"but found: {q_head_num} % {self.tp_world_size_}"
         )
-        assert kv_head_num % self.tp_world_size_ == 0, (
-            f"kv_head_num must be divisible by tp_world_size_" f"but found: {kv_head_num} % {self.tp_world_size_}"
+        assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, (
+            f"kv_head_num must be divisible by tp_world_size_ or "
+            f"tp_world_size_ must be divisible by kv_head_num, "
+            f"but found: {kv_head_num} % {self.tp_world_size_}"
         )
         q_hidden_size = (q_head_num // self.tp_world_size_) * head_dim
-        kv_hidden_size = (kv_head_num // self.tp_world_size_) * head_dim
+        kv_hidden_size = self._get_tp_padded_head_num(kv_head_num) * head_dim
         out_dims = [q_hidden_size, kv_hidden_size, kv_hidden_size]
         super().__init__(
             in_dim=in_dim,
@@ -128,13 +131,45 @@
             tp_rank=self.tp_rank_,
             tp_world_size=self.tp_world_size_,
         )
-        self.param_slicer = get_row_slice_mixin(
+        self.q_param_slicer = get_row_slice_mixin(
             self.quant_method.method_name,
             tp_rank=self.tp_rank_,
             tp_world_size=self.tp_world_size_,
-            repeat_times=self.repeat_times,
+            repeat_times=self.q_repeat_times,
         )
+        self.kv_param_slicer = get_row_slice_mixin(
+            self.quant_method.method_name,
+            tp_rank=self.tp_rank_,
+            tp_world_size=self.tp_world_size_,
+            repeat_times=self.kv_repeat_times,
+        )
+
+    def _get_param_slicer(self, sub_child_index: int):
+        """
+        sub_child_index:
+            0 -> q
+            1 -> k
+            2 -> v
+        q uses q_param_slicer; k / v use kv_param_slicer.
+        """
+        if sub_child_index == 0:
+            return self.q_param_slicer
+        else:
+            return self.kv_param_slicer
+
+    def _get_tp_padded_head_num(self, head_num: int):
+        if head_num % self.tp_world_size_ == 0:
+            return head_num // self.tp_world_size_
+        elif self.tp_world_size_ % head_num == 0:
+            self.kv_repeat_times = self.tp_world_size_ // head_num
+            return self.kv_repeat_times * head_num // self.tp_world_size_
+        else:
+            raise ValueError(
+                f"head_num must be divisible by tp_world_size_ or "
+                f"tp_world_size_ must be divisible by head_num, "
+                f"but found: {head_num} % {self.tp_world_size_}"
+            )


 class ROWBMMWeight(BMMWeightTpl):
     def __init__(
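To see what _get_tp_padded_head_num and the two repeat counts buy you, here is the arithmetic worked through on hypothetical GQA shapes (the numbers are illustrative, not from this PR). With q_head_num=32, kv_head_num=4, head_dim=128: at tp_world_size=4 the kv heads divide evenly, so each rank owns one kv head and kv_repeat_times stays 1; at tp_world_size=8 there are fewer kv heads than ranks, so each kv head is replicated kv_repeat_times = 8 // 4 = 2 times and every rank again owns exactly one kv head. A small sketch mirroring the method's logic:

def tp_padded_head_num(head_num, tp_world_size):
    """Return (per_rank_head_num, repeat_times), mirroring the PR's logic."""
    if head_num % tp_world_size == 0:
        return head_num // tp_world_size, 1
    if tp_world_size % head_num == 0:
        repeat = tp_world_size // head_num  # replicate kv heads across ranks
        return repeat * head_num // tp_world_size, repeat  # always 1 head/rank
    raise ValueError(f"{head_num} and {tp_world_size} do not divide either way")


head_dim = 128
for tp in (4, 8):
    q_heads, _ = tp_padded_head_num(32, tp)   # q must divide evenly
    kv_heads, kv_rep = tp_padded_head_num(4, tp)
    out_dims = [q_heads * head_dim, kv_heads * head_dim, kv_heads * head_dim]
    print(f"tp={tp}: kv_repeat_times={kv_rep}, out_dims={out_dims}")
# tp=4: kv_repeat_times=1, out_dims=[1024, 128, 128]
# tp=8: kv_repeat_times=2, out_dims=[512, 128, 128]

This is also why the constructor now builds two slicers: the q slicer shards 32 heads across ranks with repeat_times=1, while the kv slicer uses the (possibly larger) kv_repeat_times set as a side effect of _get_tp_padded_head_num before the slicers are created.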