[Bugfix] Fix num_heads value for simple connector when tp enabled (#12074)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
shangmingc 2025-01-20 10:56:43 +08:00 committed by GitHub
parent bbe5f9de7d
commit df450aa567
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -35,6 +35,7 @@ class SimpleConnector(KVConnectorBase):
):
self.config = config.kv_transfer_config
+self.tp_size = config.parallel_config.tensor_parallel_size
if self.config.kv_connector == "PyNcclConnector":
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
@@ -161,7 +162,7 @@ class SimpleConnector(KVConnectorBase):
end_layer = model_executable.model.end_layer
model_config = model_executable.model.config
-num_heads = model_config.num_key_value_heads
+num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
head_size = int(hidden_size / num_attention_heads)