mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 03:35:01 +08:00
[Fix] Remove divisibility requirement between num_kv_heads and tp_size in bailing_moe (#26876)
Signed-off-by: vito.yy <vito.yy@antgroup.com>
This commit is contained in:
parent
5210dc3940
commit
5c3bae1a6a
@@ -86,13 +86,12 @@ class BailingAttention(nn.Module):
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
- assert self.total_kv_heads % tp_size == 0
|
||||
assert self.total_num_heads >= self.total_kv_heads
|
||||
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
|
||||
self.q_size_per_rank = self.head_dim * self.num_heads
|
||||
- self.num_kv_heads = self.total_kv_heads // tp_size
|
||||
+ self.num_kv_heads = max(1, self.total_kv_heads // tp_size)
|
||||
self.kv_size_per_rank = self.num_kv_heads * self.head_dim
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.use_qk_norm = getattr(config, "use_qk_norm", False)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user