mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-22 09:57:03 +08:00
[P/D][NixlConnector] Support tp_size > num_kv_heads deployments (#19691)
Signed-off-by: NickLucche <nlucches@redhat.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
f17aec0d63
commit
2ebff5b77c
@ -22,6 +22,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
|||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
|
||||||
get_tp_group)
|
get_tp_group)
|
||||||
|
from vllm.distributed.utils import divide
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import _Backend
|
from vllm.platforms import _Backend
|
||||||
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
|
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
|
||||||
@ -679,11 +680,15 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
# Number of D TP workers reading from a single P TP worker. This is
|
# Number of D TP workers reading from a single P TP worker. This is
|
||||||
# 1 when P and D `--tensor-parallel-size` match.
|
# 1 when P and D `--tensor-parallel-size` match.
|
||||||
assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, (
|
tp_ratio = divide(self._tp_size[self.engine_id],
|
||||||
"Local TP size must be divisible by remote TP size.")
|
self._tp_size[engine_id])
|
||||||
tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id]
|
|
||||||
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
|
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
|
||||||
if self.use_mla:
|
|
||||||
|
# Handle tp_size>num_kv_heads: replicate KV cache.
|
||||||
|
total_num_kv_heads = self.model_config.get_total_num_kv_heads()
|
||||||
|
is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
|
||||||
|
|
||||||
|
if self.use_mla or is_kv_replicated:
|
||||||
# With MLA the only difference is in the number of blocks.
|
# With MLA the only difference is in the number of blocks.
|
||||||
remote_block_size = nixl_agent_meta.block_len // (
|
remote_block_size = nixl_agent_meta.block_len // (
|
||||||
self.slot_size_bytes)
|
self.slot_size_bytes)
|
||||||
@ -720,7 +725,7 @@ class NixlConnectorWorker:
|
|||||||
self.kv_caches_base_addr[
|
self.kv_caches_base_addr[
|
||||||
engine_id] = nixl_agent_meta.kv_caches_base_addr
|
engine_id] = nixl_agent_meta.kv_caches_base_addr
|
||||||
rank_offset = self.tp_rank % tp_ratio * self.block_len \
|
rank_offset = self.tp_rank % tp_ratio * self.block_len \
|
||||||
if not self.use_mla else 0
|
if not (self.use_mla or is_kv_replicated) else 0
|
||||||
# Register all remote blocks, but only the corresponding kv heads.
|
# Register all remote blocks, but only the corresponding kv heads.
|
||||||
for base_addr in nixl_agent_meta.kv_caches_base_addr:
|
for base_addr in nixl_agent_meta.kv_caches_base_addr:
|
||||||
for block_id in range(nixl_agent_meta.num_blocks):
|
for block_id in range(nixl_agent_meta.num_blocks):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user