mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-25 01:24:27 +08:00
[Feature][P/D]: Optimize NIXL Connector xfer Launch (#23887)
Signed-off-by: ycyaw66 <497410282@qq.com> Co-authored-by: ycyaw66 <497410282@qq.com>
This commit is contained in:
parent
a742322092
commit
6adaed42f4
@ -14,6 +14,7 @@ from dataclasses import dataclass
|
|||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
import msgspec
|
import msgspec
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
@ -1191,8 +1192,8 @@ class NixlConnectorWorker:
|
|||||||
# workers will issue xfers to parts of the P worker remote kv caches.
|
# workers will issue xfers to parts of the P worker remote kv caches.
|
||||||
|
|
||||||
# Get descs ids.
|
# Get descs ids.
|
||||||
local_block_descs_ids: list[int] = []
|
local_block_descs_ids: np.ndarray
|
||||||
remote_block_descs_ids: list[int] = []
|
remote_block_descs_ids: np.ndarray
|
||||||
if not self.block_window_per_layer:
|
if not self.block_window_per_layer:
|
||||||
# Default case: assume global attention
|
# Default case: assume global attention
|
||||||
remote_block_descs_ids = self._get_block_descs_ids(
|
remote_block_descs_ids = self._get_block_descs_ids(
|
||||||
@ -1202,6 +1203,8 @@ class NixlConnectorWorker:
|
|||||||
else:
|
else:
|
||||||
# TODO(mgoin): remove this once we have hybrid memory allocator
|
# TODO(mgoin): remove this once we have hybrid memory allocator
|
||||||
# Optimization for models with local attention (Llama 4)
|
# Optimization for models with local attention (Llama 4)
|
||||||
|
local_descs_list = []
|
||||||
|
remote_descs_list = []
|
||||||
for layer_idx, block_window in enumerate(
|
for layer_idx, block_window in enumerate(
|
||||||
self.block_window_per_layer):
|
self.block_window_per_layer):
|
||||||
# For each layer:
|
# For each layer:
|
||||||
@ -1221,8 +1224,11 @@ class NixlConnectorWorker:
|
|||||||
layer_remote_desc_ids = self._get_block_descs_ids(
|
layer_remote_desc_ids = self._get_block_descs_ids(
|
||||||
dst_engine_id, layer_remote_block_ids, layer_idx)
|
dst_engine_id, layer_remote_block_ids, layer_idx)
|
||||||
|
|
||||||
local_block_descs_ids.extend(layer_local_desc_ids)
|
local_descs_list.append(layer_local_desc_ids)
|
||||||
remote_block_descs_ids.extend(layer_remote_desc_ids)
|
remote_descs_list.append(layer_remote_desc_ids)
|
||||||
|
|
||||||
|
local_block_descs_ids = np.concatenate(local_descs_list)
|
||||||
|
remote_block_descs_ids = np.concatenate(remote_descs_list)
|
||||||
|
|
||||||
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
|
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
|
||||||
|
|
||||||
@ -1247,14 +1253,14 @@ class NixlConnectorWorker:
|
|||||||
def _get_block_descs_ids(self,
|
def _get_block_descs_ids(self,
|
||||||
engine_id: str,
|
engine_id: str,
|
||||||
block_ids: list[int],
|
block_ids: list[int],
|
||||||
layer_idx: Optional[int] = None) -> list[int]:
|
layer_idx: Optional[int] = None) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Get the descs ids for a set of block ids.
|
Get the descs ids for a set of block ids.
|
||||||
If layer_idx is provided, we use the region_ids for the given layer.
|
If layer_idx is provided, we use the region_ids for the given layer.
|
||||||
Otherwise, we use all regions.
|
Otherwise, we use all regions.
|
||||||
"""
|
"""
|
||||||
if layer_idx is None:
|
if layer_idx is None:
|
||||||
region_ids = range(self.num_regions)
|
region_ids = np.arange(self.num_regions)
|
||||||
else:
|
else:
|
||||||
assert layer_idx < self.num_layers
|
assert layer_idx < self.num_layers
|
||||||
if self.num_layers < self.num_regions:
|
if self.num_layers < self.num_regions:
|
||||||
@ -1262,20 +1268,19 @@ class NixlConnectorWorker:
|
|||||||
# the regions are organized as [K0, V0, K1, V1, ...]
|
# the regions are organized as [K0, V0, K1, V1, ...]
|
||||||
# and we select K_i and V_i
|
# and we select K_i and V_i
|
||||||
assert 2 * self.num_layers == self.num_regions
|
assert 2 * self.num_layers == self.num_regions
|
||||||
region_ids = range(2 * layer_idx, 2 * layer_idx + 2)
|
region_ids = np.arange(2 * layer_idx, 2 * layer_idx + 2)
|
||||||
else:
|
else:
|
||||||
# Otherwise, we assume we have MLA and select i-th layer
|
# Otherwise, we assume we have MLA and select i-th layer
|
||||||
assert self.num_layers == self.num_regions
|
assert self.num_layers == self.num_regions
|
||||||
region_ids = range(layer_idx, layer_idx + 1)
|
region_ids = np.arange(layer_idx, layer_idx + 1)
|
||||||
|
|
||||||
num_blocks = self.dst_num_blocks[engine_id]
|
num_blocks = self.dst_num_blocks[engine_id]
|
||||||
|
|
||||||
# Compute the desc ids for each block.
|
# Compute the desc ids for each block.
|
||||||
descs_ids: list[int] = []
|
region_ids = region_ids[:, None]
|
||||||
for reg_id in region_ids:
|
block_ids = np.array(block_ids)[None, :]
|
||||||
for block_id in block_ids:
|
descs_ids = region_ids * num_blocks + block_ids
|
||||||
descs_ids.append(reg_id * num_blocks + block_id)
|
return descs_ids.flatten()
|
||||||
return descs_ids
|
|
||||||
|
|
||||||
def get_backend_aware_kv_block_len(self):
|
def get_backend_aware_kv_block_len(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user