Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Misc] Add type alias ReqId and EngineId for better readability (#19880)
Signed-off-by: Linkun Chen <github@lkchen.net>
commit d0132f025d (parent 61f4fc5dc6)
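
The change is mechanical: two module-level aliases, EngineId = str and ReqId = str, replace bare str keys in the connector's dict annotations so each map documents what it is keyed by. A minimal sketch of the pattern (the variable name below is illustrative, not from the diff):

EngineId = str
ReqId = str

# A mapping keyed by request id now says so in its annotation; at runtime
# ReqId is just another name for str, so nothing changes in behavior.
reqs_in_flight: dict[ReqId, list[int]] = {}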
@@ -36,6 +36,8 @@ if TYPE_CHECKING:
     from vllm.v1.request import Request
 
 Transfer = tuple[int, float]  # (xfer_handle, start_time)
+EngineId = str
+ReqId = str
 GET_META_MSG = b"get_meta_msg"
 
 logger = init_logger(__name__)
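
These plain assignments are implicit aliases: type checkers treat ReqId and str as fully interchangeable, and there is no runtime effect. For contrast, a hedged sketch of the explicit PEP 613 form (Python 3.10+; not what this commit uses):

from typing import TypeAlias

# The explicit form lets checkers flag accidental rebinding of the alias;
# semantics for annotations are otherwise the same as the plain assignment.
EngineId: TypeAlias = str
ReqId: TypeAlias = str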
@@ -75,7 +77,7 @@ class ReqMeta:
 class NixlConnectorMetadata(KVConnectorMetadata):
 
     def __init__(self):
-        self.requests: dict[str, ReqMeta] = {}
+        self.requests: dict[ReqId, ReqMeta] = {}
 
     def add_new_req(
         self,
@@ -96,16 +98,17 @@ class NixlConnector(KVConnectorBase_V1):
 
     def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
         assert vllm_config.kv_transfer_config is not None
-        self.engine_id = vllm_config.kv_transfer_config.engine_id
+        assert vllm_config.kv_transfer_config.engine_id is not None
+        self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
 
         if role == KVConnectorRole.SCHEDULER:
             self.connector_scheduler: Optional[NixlConnectorScheduler] = \
-                NixlConnectorScheduler(vllm_config, str(self.engine_id))
+                NixlConnectorScheduler(vllm_config, self.engine_id)
             self.connector_worker: Optional[NixlConnectorWorker] = None
         elif role == KVConnectorRole.WORKER:
             self.connector_scheduler = None
             self.connector_worker = NixlConnectorWorker(
-                vllm_config, str(self.engine_id))
+                vllm_config, self.engine_id)
 
     ############################################################
     # Scheduler Side Methods
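
Beyond the alias, this hunk adds a narrowing assert: kv_transfer_config.engine_id is presumably typed Optional, and asserting it is not None lets the checker accept the EngineId (i.e. str) annotation on the attribute without a cast, which is also what makes the two str(...) wrappers below it removable. A minimal sketch of the idiom, with hypothetical names:

from typing import Optional

def set_engine_id(engine_id: Optional[str]) -> str:
    # Without this assert, mypy rejects the return:
    # "Optional[str]" is not "str".
    assert engine_id is not None
    return engine_id  # narrowed from Optional[str] to str past the assert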
@@ -179,7 +182,7 @@ class NixlConnectorScheduler:
     def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
-        self.engine_id = engine_id
+        self.engine_id: EngineId = engine_id
         self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
         self.side_channel_port = (
             envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
@@ -190,7 +193,7 @@ class NixlConnectorScheduler:
         # Requests that need to start recv.
         # New requests are added by update_state_after_alloc in
         # the scheduler. Used to make metadata passed to Worker.
-        self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
+        self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
 
     def get_num_new_matched_tokens(
             self, request: "Request",
@@ -332,19 +335,19 @@ class NixlConnectorWorker:
         # Agent.
         self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
         # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
-        self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict)
+        self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
 
         # NIXL handshake port.
         # NOTE(rob): Within a DP group, each DP rank gets its own
         # base port (which is sent in the KVTransferParams).
         # Each TP rank listens/queries on the base_port + tp_rank.
-        self.side_channel_port = (
+        self.side_channel_port: int = (
             envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
             vllm_config.parallel_config.data_parallel_rank_local *
             vllm_config.parallel_config.tensor_parallel_size)
 
         # Metadata.
-        self.engine_id = engine_id
+        self.engine_id: EngineId = engine_id
         self.tp_rank = get_tensor_model_parallel_rank()
         self.world_size = get_tensor_model_parallel_world_size()
         self.tp_group = get_tp_group()
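
The newly annotated side_channel_port encodes a simple layout: each local DP rank owns a contiguous block of tp_size ports starting at the env-provided base, and (per the NOTE) each TP rank later offsets into that block with base_port + tp_rank. A worked example with assumed values (the real ones come from VLLM_NIXL_SIDE_CHANNEL_PORT and the parallel config):

base_port = 6000  # assumed stand-in for envs.VLLM_NIXL_SIDE_CHANNEL_PORT
tp_size = 4       # assumed tensor_parallel_size

for dp_rank_local in range(2):
    block_start = base_port + dp_rank_local * tp_size
    print(dp_rank_local, [block_start + tp_rank for tp_rank in range(tp_size)])
# 0 [6000, 6001, 6002, 6003]
# 1 [6004, 6005, 6006, 6007]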
@@ -354,7 +357,7 @@ class NixlConnectorWorker:
 
         # Map of engine_id -> kv_caches_base_addr. For TP case, each local
         # rank will still only pull from a single remote TP worker.
-        self.kv_caches_base_addr: dict[str, list[int]] = {}
+        self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
 
         # Number of NIXL regions. Currently one region per cache
         # (so 1 per layer for MLA, otherwise 2 per layer)
@@ -364,23 +367,23 @@ class NixlConnectorWorker:
         # nixl_prepped_dlist_handle.
         self.src_xfer_side_handle: int = 0
         # Map of engine_id -> nixl_prepped_dlist_handle (int)].
-        self.dst_xfer_side_handles: dict[str, int] = {}
+        self.dst_xfer_side_handles: dict[EngineId, int] = {}
 
         # Map of engine_id -> num_blocks. All ranks in the same deployment will
         # have the same number of blocks.
-        self.dst_num_blocks: dict[str, int] = {}
+        self.dst_num_blocks: dict[EngineId, int] = {}
         self._registered_descs: list[Any] = []
 
         # In progress transfers.
         # [req_id -> list[handle]]
-        self._recving_transfers = defaultdict[str, list[Transfer]](list)
+        self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
 
         # Complete transfer tracker. Used by the rank 0 to track finished
         # transactions on ranks 1 to N-1.
         # [req_id -> count]
-        self._done_recving_count: defaultdict[str,
+        self._done_recving_count: defaultdict[ReqId,
                                               int] = defaultdict(lambda: 0)
-        self._done_sending_count: defaultdict[str,
+        self._done_sending_count: defaultdict[ReqId,
                                               int] = defaultdict(lambda: 0)
 
         # Background thread for establishing new connections.
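
Note the defaultdict[ReqId, list[Transfer]](list) spelling: since Python 3.9, subscripting a concrete container class yields a callable types.GenericAlias, so the subscripted class can be instantiated directly and the variable needs no separate annotation. A quick self-contained demonstration:

from collections import defaultdict

Transfer = tuple[int, float]  # as in the module: (xfer_handle, start_time)

# Calling the GenericAlias constructs a plain defaultdict whose
# default_factory is list; the subscript survives for type checkers.
recving = defaultdict[str, list[Transfer]](list)
recving["req-0"].append((42, 0.125))
print(recving)  # defaultdict(<class 'list'>, {'req-0': [(42, 0.125)]})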
@@ -408,10 +411,10 @@ class NixlConnectorWorker:
         self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
         logger.debug("Detected attention backend %s", self.backend_name)
 
-        self._tp_size: dict[str, int] = {self.engine_id: self.world_size}
+        self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
         # With heterogeneous TP, P must wait for all assigned D TP workers to
         # finish reading before safely freeing the blocks.
-        self.consumer_notification_counts_by_req = defaultdict[str, int](int)
+        self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
 
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
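
Taken together, the aliases cost nothing at runtime but make signatures self-documenting wherever both kinds of identifier appear. A hedged example of how a helper might read with them (hypothetical function, not part of the diff):

def transfer_key(engine_id: EngineId, req_id: ReqId) -> str:
    # Both arguments are plain str at runtime; the aliases only record
    # which kind of identifier each one is.
    return f"{engine_id}/{req_id}"

That is the whole of #19880: annotation-level renames across the scheduler and worker maps, plus one narrowing assert that makes the str(...) wrappers unnecessary.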