mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 09:06:19 +08:00
[KVConnector] Keep KVTransferParams as a dict (#18033)
This commit is contained in:
parent
d066e52013
commit
59dd311cf5
@ -1,13 +1,11 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
|
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
|
||||||
ModelConfig, SchedulerConfig, VllmConfig)
|
ModelConfig, SchedulerConfig, VllmConfig)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
|
||||||
NixlKVTransferParams)
|
|
||||||
from vllm.v1.core.sched.scheduler import Scheduler
|
from vllm.v1.core.sched.scheduler import Scheduler
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheGroupSpec)
|
KVCacheGroupSpec)
|
||||||
@ -124,20 +122,20 @@ def create_request(
|
|||||||
) -> Request:
|
) -> Request:
|
||||||
"""Make dummy request for testing."""
|
"""Make dummy request for testing."""
|
||||||
|
|
||||||
|
kv_transfer_params: Optional[dict[str, Any]] = None
|
||||||
|
|
||||||
if do_remote_decode:
|
if do_remote_decode:
|
||||||
assert not do_remote_prefill
|
assert not do_remote_prefill
|
||||||
kv_transfer_params = NixlKVTransferParams(do_remote_prefill=False,
|
kv_transfer_params = dict(do_remote_prefill=False,
|
||||||
do_remote_decode=True)
|
do_remote_decode=True)
|
||||||
elif do_remote_prefill:
|
elif do_remote_prefill:
|
||||||
kv_transfer_params = NixlKVTransferParams(
|
kv_transfer_params = dict(do_remote_prefill=True,
|
||||||
do_remote_prefill=True,
|
|
||||||
do_remote_decode=False,
|
do_remote_decode=False,
|
||||||
remote_engine_id="my-engine-id",
|
remote_engine_id="my-engine-id",
|
||||||
remote_block_ids=list(range(num_remote_blocks)),
|
remote_block_ids=list(
|
||||||
|
range(num_remote_blocks)),
|
||||||
remote_host="my-host",
|
remote_host="my-host",
|
||||||
remote_port=1234)
|
remote_port=1234)
|
||||||
else:
|
|
||||||
kv_transfer_params = None
|
|
||||||
|
|
||||||
max_tokens = 1 if do_remote_decode else max_tokens
|
max_tokens = 1 if do_remote_decode else max_tokens
|
||||||
sampling_params = SamplingParams(max_tokens=max_tokens)
|
sampling_params = SamplingParams(max_tokens=max_tokens)
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
KVConnectorBase_V1, KVConnectorRole, KVTransferParams)
|
KVConnectorBase_V1, KVConnectorRole)
|
||||||
|
|
||||||
__all__ = ["KVConnectorRole", "KVConnectorBase_V1", "KVTransferParams"]
|
__all__ = ["KVConnectorRole", "KVConnectorBase_V1"]
|
||||||
|
|||||||
@ -48,23 +48,6 @@ class KVConnectorRole(enum.Enum):
|
|||||||
WORKER = 1
|
WORKER = 1
|
||||||
|
|
||||||
|
|
||||||
class KVTransferParams:
|
|
||||||
"""
|
|
||||||
Abstract KVTransferParams used to send KVTransfer
|
|
||||||
parameters between instances of vLLM.
|
|
||||||
|
|
||||||
Specific instances of KVConnector customize this
|
|
||||||
method for serializing / deserializing msgs sent
|
|
||||||
via the HTTP protocol.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_raw_dict(
|
|
||||||
raw_dict: Optional[dict[str,
|
|
||||||
Any]]) -> Optional["KVTransferParams"]:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class KVConnectorMetadata:
|
class KVConnectorMetadata:
|
||||||
"""
|
"""
|
||||||
@ -75,7 +58,6 @@ class KVConnectorMetadata:
|
|||||||
|
|
||||||
|
|
||||||
class KVConnectorBase_V1(ABC):
|
class KVConnectorBase_V1(ABC):
|
||||||
_KVTransferParams = KVTransferParams
|
|
||||||
|
|
||||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -213,13 +195,6 @@ class KVConnectorBase_V1(ABC):
|
|||||||
# Scheduler-side methods
|
# Scheduler-side methods
|
||||||
# ==============================
|
# ==============================
|
||||||
|
|
||||||
def set_kv_transfer_params(self, request: "Request"):
|
|
||||||
"""Parse raw KV Transfer params."""
|
|
||||||
assert request.kv_transfer_params is None
|
|
||||||
kv_transfer_params = self._KVTransferParams.from_raw_dict(
|
|
||||||
request.raw_kv_transfer_params)
|
|
||||||
request.kv_transfer_params = kv_transfer_params
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_num_new_matched_tokens(
|
def get_num_new_matched_tokens(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -16,7 +16,7 @@ import zmq
|
|||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, KVTransferParams)
|
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
|
||||||
get_tp_group)
|
get_tp_group)
|
||||||
@ -44,56 +44,6 @@ except ImportError:
|
|||||||
NixlWrapper = None
|
NixlWrapper = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class NixlKVTransferParams(KVTransferParams):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
do_remote_prefill: bool,
|
|
||||||
do_remote_decode: bool,
|
|
||||||
remote_block_ids: Optional[list[int]] = None,
|
|
||||||
remote_host: Optional[str] = None,
|
|
||||||
remote_port: Optional[int] = None,
|
|
||||||
remote_engine_id: Optional[str] = None,
|
|
||||||
):
|
|
||||||
self.do_remote_prefill = do_remote_prefill
|
|
||||||
self.do_remote_decode = do_remote_decode
|
|
||||||
self.remote_block_ids = remote_block_ids
|
|
||||||
self.remote_host = remote_host
|
|
||||||
self.remote_port = remote_port
|
|
||||||
self.remote_engine_id = remote_engine_id
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_raw_dict(
|
|
||||||
raw_dict: Optional[dict[str,
|
|
||||||
Any]]) -> Optional["NixlKVTransferParams"]:
|
|
||||||
|
|
||||||
# If no raw transfer params passed, return None.
|
|
||||||
if raw_dict is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Validate the request is formatted properly.
|
|
||||||
if (("do_remote_prefill" not in raw_dict)
|
|
||||||
or ("do_remote_decode" not in raw_dict)
|
|
||||||
or ("remote_block_ids" not in raw_dict)
|
|
||||||
or ("remote_host" not in raw_dict)
|
|
||||||
or ("remote_port" not in raw_dict)
|
|
||||||
or ("remote_engine_id" not in raw_dict)):
|
|
||||||
logger.warning(
|
|
||||||
"Got invalid KVTransferParams: %s. This "
|
|
||||||
"request will not utilize KVTransfer", raw_dict)
|
|
||||||
return None
|
|
||||||
|
|
||||||
return NixlKVTransferParams(
|
|
||||||
do_remote_prefill=raw_dict["do_remote_prefill"],
|
|
||||||
do_remote_decode=raw_dict["do_remote_decode"],
|
|
||||||
remote_block_ids=raw_dict["remote_block_ids"],
|
|
||||||
remote_host=raw_dict["remote_host"],
|
|
||||||
remote_port=raw_dict["remote_port"],
|
|
||||||
remote_engine_id=raw_dict["remote_engine_id"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class NixlAgentMetadata(
|
class NixlAgentMetadata(
|
||||||
msgspec.Struct,
|
msgspec.Struct,
|
||||||
omit_defaults=True, # type: ignore[call-arg]
|
omit_defaults=True, # type: ignore[call-arg]
|
||||||
@ -123,25 +73,18 @@ class NixlConnectorMetadata(KVConnectorMetadata):
|
|||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
local_block_ids: list[int],
|
local_block_ids: list[int],
|
||||||
kv_transfer_params: NixlKVTransferParams,
|
kv_transfer_params: dict[str, Any],
|
||||||
):
|
):
|
||||||
assert request_id not in self.requests
|
|
||||||
assert kv_transfer_params.remote_block_ids is not None
|
|
||||||
assert kv_transfer_params.remote_engine_id is not None
|
|
||||||
assert kv_transfer_params.remote_host is not None
|
|
||||||
assert kv_transfer_params.remote_port is not None
|
|
||||||
|
|
||||||
self.requests[request_id] = ReqMeta(
|
self.requests[request_id] = ReqMeta(
|
||||||
local_block_ids=local_block_ids,
|
local_block_ids=local_block_ids,
|
||||||
remote_block_ids=kv_transfer_params.remote_block_ids,
|
remote_block_ids=kv_transfer_params["remote_block_ids"],
|
||||||
remote_engine_id=kv_transfer_params.remote_engine_id,
|
remote_engine_id=kv_transfer_params["remote_engine_id"],
|
||||||
remote_host=kv_transfer_params.remote_host,
|
remote_host=kv_transfer_params["remote_host"],
|
||||||
remote_port=kv_transfer_params.remote_port,
|
remote_port=kv_transfer_params["remote_port"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class NixlConnector(KVConnectorBase_V1):
|
class NixlConnector(KVConnectorBase_V1):
|
||||||
_KVTransferParams: type[NixlKVTransferParams] = NixlKVTransferParams
|
|
||||||
|
|
||||||
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
|
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
|
||||||
assert vllm_config.kv_transfer_config is not None
|
assert vllm_config.kv_transfer_config is not None
|
||||||
@ -253,52 +196,52 @@ class NixlConnectorScheduler:
|
|||||||
asynchronously (between scheduler steps).
|
asynchronously (between scheduler steps).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
params = request.kv_transfer_params
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"NIXLConnector get_num_new_matched_tokens: "
|
"NIXLConnector get_num_new_matched_tokens: "
|
||||||
"num_computed_tokens=%s, kv_transfer_params=%s",
|
"num_computed_tokens=%s, kv_transfer_params=%s",
|
||||||
num_computed_tokens, request.kv_transfer_params)
|
num_computed_tokens, params)
|
||||||
|
|
||||||
# No KVTransfer for this request.
|
|
||||||
if request.kv_transfer_params is None:
|
|
||||||
return 0, False
|
|
||||||
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
|
|
||||||
|
|
||||||
|
if params is not None and params.get("do_remote_prefill"):
|
||||||
# Remote prefill: get all prompt blocks from remote.
|
# Remote prefill: get all prompt blocks from remote.
|
||||||
if request.kv_transfer_params.do_remote_prefill:
|
|
||||||
assert num_computed_tokens % self.block_size == 0
|
assert num_computed_tokens % self.block_size == 0
|
||||||
rounded_num_prompt_tokens = round_down(
|
rounded_num_prompt_tokens = round_down(
|
||||||
len(request.prompt_token_ids), self.block_size)
|
len(request.prompt_token_ids), self.block_size)
|
||||||
count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
|
count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
|
||||||
return count, count > 0
|
return count, count > 0
|
||||||
|
|
||||||
|
# No remote prefill for this request.
|
||||||
return 0, False
|
return 0, False
|
||||||
|
|
||||||
def update_state_after_alloc(self, request: "Request",
|
def update_state_after_alloc(self, request: "Request",
|
||||||
blocks: "KVCacheBlocks",
|
blocks: "KVCacheBlocks",
|
||||||
num_external_tokens: int):
|
num_external_tokens: int):
|
||||||
|
|
||||||
|
params = request.kv_transfer_params
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"NIXLConnector update_state_after_alloc: "
|
"NIXLConnector update_state_after_alloc: "
|
||||||
"num_external_tokens=%s, kv_transfer_params=%s",
|
"num_external_tokens=%s, kv_transfer_params=%s",
|
||||||
num_external_tokens, request.kv_transfer_params)
|
num_external_tokens, params)
|
||||||
|
|
||||||
if request.kv_transfer_params is None:
|
if params is not None and params.get("do_remote_prefill"):
|
||||||
return
|
|
||||||
|
|
||||||
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
|
|
||||||
if request.kv_transfer_params.do_remote_prefill:
|
|
||||||
# NOTE(rob): if prompt < block_size, no remote blocks
|
# NOTE(rob): if prompt < block_size, no remote blocks
|
||||||
# since the remote only sends fully computed blocks, so
|
# since the remote only sends fully computed blocks, so
|
||||||
# skip recving for this request. num_external_tokens
|
# skip recving for this request. num_external_tokens
|
||||||
# should be 0 if there are no remote blocks.
|
# should be 0 if there are no remote blocks.
|
||||||
if request.kv_transfer_params.remote_block_ids:
|
if params.get("remote_block_ids"):
|
||||||
|
if all(p in params for p in ("remote_engine_id", "remote_host",
|
||||||
|
"remote_port")):
|
||||||
# Get unhashed blocks to pull from remote.
|
# Get unhashed blocks to pull from remote.
|
||||||
self._reqs_need_recv[request.request_id] = (
|
self._reqs_need_recv[request.request_id] = (
|
||||||
request, blocks.get_unhashed_block_ids())
|
request, blocks.get_unhashed_block_ids())
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Got invalid KVTransferParams: %s. This "
|
||||||
|
"request will not utilize KVTransfer", params)
|
||||||
else:
|
else:
|
||||||
assert num_external_tokens == 0
|
assert num_external_tokens == 0
|
||||||
# Only trigger 1 KV transfer per request.
|
# Only trigger 1 KV transfer per request.
|
||||||
request.kv_transfer_params.do_remote_prefill = False
|
params["do_remote_prefill"] = False
|
||||||
|
|
||||||
def build_connector_meta(
|
def build_connector_meta(
|
||||||
self,
|
self,
|
||||||
@ -308,7 +251,7 @@ class NixlConnectorScheduler:
|
|||||||
|
|
||||||
# Loop through scheduled reqs and convert to ReqMeta.
|
# Loop through scheduled reqs and convert to ReqMeta.
|
||||||
for req_id, (req, block_ids) in self._reqs_need_recv.items():
|
for req_id, (req, block_ids) in self._reqs_need_recv.items():
|
||||||
assert isinstance(req.kv_transfer_params, NixlKVTransferParams)
|
assert req.kv_transfer_params is not None
|
||||||
meta.add_new_req(
|
meta.add_new_req(
|
||||||
request_id=req_id,
|
request_id=req_id,
|
||||||
local_block_ids=block_ids,
|
local_block_ids=block_ids,
|
||||||
@ -330,34 +273,30 @@ class NixlConnectorScheduler:
|
|||||||
should be freed now or will be sent asynchronously and freed later.
|
should be freed now or will be sent asynchronously and freed later.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
params = request.kv_transfer_params
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"NIXLConnector request_finished, "
|
"NIXLConnector request_finished, request_status=%s, "
|
||||||
"request_status=%s, kv_transfer_params=%s", request.status,
|
"kv_transfer_params=%s", request.status, params)
|
||||||
request.kv_transfer_params)
|
|
||||||
|
|
||||||
if request.kv_transfer_params is None:
|
if (params is None or not params.get("do_remote_decode")
|
||||||
return False, None
|
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
|
||||||
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
|
|
||||||
|
|
||||||
if ((not request.kv_transfer_params.do_remote_decode)
|
|
||||||
or (request.status != RequestStatus.FINISHED_LENGTH_CAPPED)):
|
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
# Get computed blocks.
|
# Get computed blocks.
|
||||||
all_full = request.num_computed_tokens % self.block_size == 0
|
all_full = request.num_computed_tokens % self.block_size == 0
|
||||||
computed_block_ids = (block_ids if all_full else block_ids[:-1])
|
computed_block_ids = block_ids if all_full else block_ids[:-1]
|
||||||
|
|
||||||
# If prompt < block_size, no xfer so free blocks immediately.
|
# If prompt < block_size, no xfer so free blocks immediately.
|
||||||
delay_free_blocks = len(computed_block_ids) > 0
|
delay_free_blocks = len(computed_block_ids) > 0
|
||||||
|
|
||||||
return delay_free_blocks, NixlKVTransferParams(
|
return delay_free_blocks, dict(
|
||||||
do_remote_prefill=True,
|
do_remote_prefill=True,
|
||||||
do_remote_decode=False,
|
do_remote_decode=False,
|
||||||
remote_block_ids=computed_block_ids,
|
remote_block_ids=computed_block_ids,
|
||||||
remote_engine_id=self.engine_id,
|
remote_engine_id=self.engine_id,
|
||||||
remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST,
|
remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST,
|
||||||
remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT,
|
remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT,
|
||||||
).__dict__
|
)
|
||||||
|
|
||||||
|
|
||||||
class NixlConnectorWorker:
|
class NixlConnectorWorker:
|
||||||
|
|||||||
@ -12,8 +12,7 @@ from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
|
|||||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||||
KVConnectorFactory)
|
KVConnectorFactory)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
|
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
|
||||||
KVConnectorRole,
|
KVConnectorRole)
|
||||||
KVTransferParams)
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||||
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
|
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
|
||||||
@ -931,8 +930,13 @@ class Scheduler(SchedulerInterface):
|
|||||||
return self.connector
|
return self.connector
|
||||||
|
|
||||||
def _connector_finished(
|
def _connector_finished(
|
||||||
self, request: Request) -> tuple[bool, Optional[KVTransferParams]]:
|
self, request: Request) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||||
"""Invoke the KV connector request_finished() method if applicable."""
|
"""
|
||||||
|
Invoke the KV connector request_finished() method if applicable.
|
||||||
|
|
||||||
|
Returns optional kv transfer parameters to be included with the
|
||||||
|
request outputs.
|
||||||
|
"""
|
||||||
if self.connector is None:
|
if self.connector is None:
|
||||||
return False, None
|
return False, None
|
||||||
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
|
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
|
||||||
|
|||||||
@ -182,13 +182,9 @@ class EngineCore:
|
|||||||
# Start grammar compilation asynchronously
|
# Start grammar compilation asynchronously
|
||||||
self.structured_output_manager.grammar_init(req)
|
self.structured_output_manager.grammar_init(req)
|
||||||
|
|
||||||
if req.raw_kv_transfer_params is not None:
|
if req.kv_transfer_params is not None and (
|
||||||
if (kv_connector := self.scheduler.get_kv_connector()):
|
not self.scheduler.get_kv_connector()):
|
||||||
# Parse raw KV transfer params via connector.
|
logger.warning("Got kv_transfer_params, but no KVConnector found. "
|
||||||
kv_connector.set_kv_transfer_params(req)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"Got KVTransferParams, but no KVConnector found. "
|
|
||||||
"Disabling KVTransfer for this request.")
|
"Disabling KVTransfer for this request.")
|
||||||
|
|
||||||
self.scheduler.add_request(req)
|
self.scheduler.add_request(req)
|
||||||
|
|||||||
@ -3,7 +3,6 @@
|
|||||||
import enum
|
import enum
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
|
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVTransferParams
|
|
||||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.utils import is_list_of
|
from vllm.utils import is_list_of
|
||||||
@ -62,14 +61,10 @@ class Request:
|
|||||||
self.num_encoder_inputs = len(self.mm_inputs)
|
self.num_encoder_inputs = len(self.mm_inputs)
|
||||||
self.has_encoder_inputs = self.num_encoder_inputs > 0
|
self.has_encoder_inputs = self.num_encoder_inputs > 0
|
||||||
|
|
||||||
# P/D: KV transfer parameters (raw and parsed).
|
# P/D: Connector-specific KV transfer parameters.
|
||||||
raw_params = (None if sampling_params.extra_args is None
|
kv_params = (None if sampling_params.extra_args is None else
|
||||||
else sampling_params.extra_args.get(
|
sampling_params.extra_args.get("kv_transfer_params"))
|
||||||
"kv_transfer_params", None))
|
self.kv_transfer_params: Optional[dict[str, Any]] = kv_params
|
||||||
self.raw_kv_transfer_params: Optional[dict[str, Any]] = raw_params
|
|
||||||
# Each connector parses the raw dictionary and sets this
|
|
||||||
# attr the first time that the request is processed.
|
|
||||||
self.kv_transfer_params: Optional[KVTransferParams] = None
|
|
||||||
|
|
||||||
# Sanity check
|
# Sanity check
|
||||||
assert len(self.mm_inputs) == len(self.mm_positions)
|
assert len(self.mm_inputs) == len(self.mm_positions)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user