mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-24 18:17:55 +08:00
fix mypy for core and engine
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
2ba60ec7fe
commit
8244ff7fee
@ -36,12 +36,15 @@ FILES = [
|
|||||||
"vllm/transformers_utils",
|
"vllm/transformers_utils",
|
||||||
"vllm/triton_utils",
|
"vllm/triton_utils",
|
||||||
"vllm/usage",
|
"vllm/usage",
|
||||||
|
"vllm/v1/core",
|
||||||
|
"vllm/v1/engine",
|
||||||
]
|
]
|
||||||
|
|
||||||
# After fixing errors resulting from changing follow_imports
|
# After fixing errors resulting from changing follow_imports
|
||||||
# from "skip" to "silent", move the following directories to FILES
|
# from "skip" to "silent", move the following directories to FILES
|
||||||
SEPARATE_GROUPS = [
|
SEPARATE_GROUPS = [
|
||||||
"tests",
|
"tests",
|
||||||
|
# v0 related
|
||||||
"vllm/attention",
|
"vllm/attention",
|
||||||
"vllm/compilation",
|
"vllm/compilation",
|
||||||
"vllm/engine",
|
"vllm/engine",
|
||||||
@ -50,7 +53,16 @@ SEPARATE_GROUPS = [
|
|||||||
"vllm/model_executor",
|
"vllm/model_executor",
|
||||||
"vllm/plugins",
|
"vllm/plugins",
|
||||||
"vllm/worker",
|
"vllm/worker",
|
||||||
"vllm/v1",
|
# v1 related
|
||||||
|
"vllm/v1/attention",
|
||||||
|
"vllm/v1/executor",
|
||||||
|
"vllm/v1/kv_offload",
|
||||||
|
"vllm/v1/metrics",
|
||||||
|
"vllm/v1/pool",
|
||||||
|
"vllm/v1/sample",
|
||||||
|
"vllm/v1/spec_decode",
|
||||||
|
"vllm/v1/structured_output",
|
||||||
|
"vllm/v1/worker",
|
||||||
]
|
]
|
||||||
|
|
||||||
# TODO(woosuk): Include the code from Megatron and HuggingFace.
|
# TODO(woosuk): Include the code from Megatron and HuggingFace.
|
||||||
|
|||||||
@ -72,6 +72,7 @@ class EngineClient(ABC):
|
|||||||
lora_request: LoRARequest | None = None,
|
lora_request: LoRARequest | None = None,
|
||||||
trace_headers: Mapping[str, str] | None = None,
|
trace_headers: Mapping[str, str] | None = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
|
truncate_prompt_tokens: int | None = None,
|
||||||
tokenization_kwargs: dict[str, Any] | None = None,
|
tokenization_kwargs: dict[str, Any] | None = None,
|
||||||
) -> AsyncGenerator[PoolingRequestOutput, None]:
|
) -> AsyncGenerator[PoolingRequestOutput, None]:
|
||||||
"""Generate outputs for a request from a pooling model."""
|
"""Generate outputs for a request from a pooling model."""
|
||||||
|
|||||||
@ -165,7 +165,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
self.kv_cache_manager = KVCacheManager(
|
self.kv_cache_manager = KVCacheManager(
|
||||||
kv_cache_config=kv_cache_config,
|
kv_cache_config=kv_cache_config,
|
||||||
max_model_len=self.max_model_len,
|
max_model_len=self.max_model_len,
|
||||||
enable_caching=self.cache_config.enable_prefix_caching,
|
enable_caching=bool(self.cache_config.enable_prefix_caching),
|
||||||
use_eagle=self.use_eagle,
|
use_eagle=self.use_eagle,
|
||||||
log_stats=self.log_stats,
|
log_stats=self.log_stats,
|
||||||
enable_kv_cache_events=self.enable_kv_cache_events,
|
enable_kv_cache_events=self.enable_kv_cache_events,
|
||||||
@ -392,7 +392,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
skipped_waiting_requests.prepend_request(request)
|
skipped_waiting_requests.prepend_request(request)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
num_external_computed_tokens = 0
|
num_external_computed_tokens: int | None = 0
|
||||||
load_kv_async = False
|
load_kv_async = False
|
||||||
|
|
||||||
# Get already-cached tokens.
|
# Get already-cached tokens.
|
||||||
@ -419,8 +419,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# Total computed tokens (local + external).
|
# Total computed tokens (local + external).
|
||||||
num_computed_tokens = (
|
num_computed_tokens = num_new_local_computed_tokens + (
|
||||||
num_new_local_computed_tokens + num_external_computed_tokens
|
num_external_computed_tokens or 0
|
||||||
)
|
)
|
||||||
# KVTransfer: WAITING reqs have num_computed_tokens > 0
|
# KVTransfer: WAITING reqs have num_computed_tokens > 0
|
||||||
# after async KV recvs are completed.
|
# after async KV recvs are completed.
|
||||||
@ -434,6 +434,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
|
|
||||||
# KVTransfer: loading remote KV, do not allocate for new work.
|
# KVTransfer: loading remote KV, do not allocate for new work.
|
||||||
if load_kv_async:
|
if load_kv_async:
|
||||||
|
assert isinstance(num_external_computed_tokens, int)
|
||||||
assert num_external_computed_tokens > 0
|
assert num_external_computed_tokens > 0
|
||||||
num_new_tokens = 0
|
num_new_tokens = 0
|
||||||
# Number of tokens to be scheduled.
|
# Number of tokens to be scheduled.
|
||||||
@ -503,7 +504,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
|
|
||||||
new_blocks = self.kv_cache_manager.allocate_slots(
|
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||||
request,
|
request,
|
||||||
num_new_tokens + num_external_computed_tokens,
|
num_new_tokens + (num_external_computed_tokens or 0),
|
||||||
num_new_local_computed_tokens,
|
num_new_local_computed_tokens,
|
||||||
new_computed_blocks,
|
new_computed_blocks,
|
||||||
num_lookahead_tokens=effective_lookahead_tokens,
|
num_lookahead_tokens=effective_lookahead_tokens,
|
||||||
@ -523,7 +524,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
self.connector.update_state_after_alloc(
|
self.connector.update_state_after_alloc(
|
||||||
request,
|
request,
|
||||||
new_computed_blocks + new_blocks,
|
new_computed_blocks + new_blocks,
|
||||||
num_external_computed_tokens,
|
num_external_computed_tokens or 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Request was already popped from self.waiting
|
# Request was already popped from self.waiting
|
||||||
@ -916,13 +917,13 @@ class Scheduler(SchedulerInterface):
|
|||||||
|
|
||||||
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
|
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
|
||||||
spec_decoding_stats: SpecDecodingStats | None = None
|
spec_decoding_stats: SpecDecodingStats | None = None
|
||||||
kv_connector_stats = (
|
kv_connector_stats: KVConnectorStats | None = (
|
||||||
kv_connector_output.kv_connector_stats if kv_connector_output else None
|
kv_connector_output.kv_connector_stats if kv_connector_output else None
|
||||||
)
|
)
|
||||||
if kv_connector_stats and self.connector:
|
if kv_connector_stats and self.connector:
|
||||||
stats = self.connector.get_kv_connector_stats()
|
kv_stats = self.connector.get_kv_connector_stats()
|
||||||
if stats:
|
if kv_stats:
|
||||||
kv_connector_stats = kv_connector_stats.aggregate(stats)
|
kv_connector_stats = kv_connector_stats.aggregate(kv_stats)
|
||||||
|
|
||||||
failed_kv_load_req_ids = None
|
failed_kv_load_req_ids = None
|
||||||
if kv_connector_output and kv_connector_output.invalid_block_ids:
|
if kv_connector_output and kv_connector_output.invalid_block_ids:
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import socket
|
|||||||
import time
|
import time
|
||||||
from collections.abc import AsyncGenerator, Iterable, Mapping
|
from collections.abc import AsyncGenerator, Iterable, Mapping
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -122,10 +122,9 @@ class AsyncLLM(EngineClient):
|
|||||||
self.output_processor = OutputProcessor(
|
self.output_processor = OutputProcessor(
|
||||||
self.tokenizer, log_stats=self.log_stats
|
self.tokenizer, log_stats=self.log_stats
|
||||||
)
|
)
|
||||||
if self.observability_config.otlp_traces_endpoint is not None:
|
endpoint = getattr(self.observability_config, "otlp_traces_endpoint", None)
|
||||||
tracer = init_tracer(
|
if endpoint is not None:
|
||||||
"vllm.llm_engine", self.observability_config.otlp_traces_endpoint
|
tracer = init_tracer("vllm.llm_engine", endpoint)
|
||||||
)
|
|
||||||
self.output_processor.tracer = tracer
|
self.output_processor.tracer = tracer
|
||||||
|
|
||||||
# EngineCore (starts the engine in background process).
|
# EngineCore (starts the engine in background process).
|
||||||
@ -257,7 +256,9 @@ class AsyncLLM(EngineClient):
|
|||||||
if engine_core := getattr(self, "engine_core", None):
|
if engine_core := getattr(self, "engine_core", None):
|
||||||
engine_core.shutdown()
|
engine_core.shutdown()
|
||||||
|
|
||||||
cancel_task_threadsafe(getattr(self, "output_handler", None))
|
handler = getattr(self, "output_handler", None)
|
||||||
|
if handler is not None:
|
||||||
|
cancel_task_threadsafe(handler)
|
||||||
|
|
||||||
async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||||
return await self.engine_core.get_supported_tasks_async()
|
return await self.engine_core.get_supported_tasks_async()
|
||||||
@ -305,7 +306,10 @@ class AsyncLLM(EngineClient):
|
|||||||
priority,
|
priority,
|
||||||
data_parallel_rank,
|
data_parallel_rank,
|
||||||
)
|
)
|
||||||
prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt")
|
if isinstance(prompt, str):
|
||||||
|
prompt_text = prompt
|
||||||
|
elif isinstance(prompt, Mapping):
|
||||||
|
prompt_text = cast(str | None, prompt.get("prompt"))
|
||||||
|
|
||||||
if is_pooling or params.n == 1:
|
if is_pooling or params.n == 1:
|
||||||
await self._add_request(request, prompt_text, None, 0, queue)
|
await self._add_request(request, prompt_text, None, 0, queue)
|
||||||
@ -427,6 +431,7 @@ class AsyncLLM(EngineClient):
|
|||||||
# Note: both OutputProcessor and EngineCore handle their
|
# Note: both OutputProcessor and EngineCore handle their
|
||||||
# own request cleanup based on finished.
|
# own request cleanup based on finished.
|
||||||
finished = out.finished
|
finished = out.finished
|
||||||
|
assert isinstance(out, RequestOutput)
|
||||||
yield out
|
yield out
|
||||||
|
|
||||||
# If the request is disconnected by the client, generate()
|
# If the request is disconnected by the client, generate()
|
||||||
|
|||||||
@ -467,9 +467,10 @@ class EngineCore:
|
|||||||
self,
|
self,
|
||||||
tensorizer_config,
|
tensorizer_config,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model_executor.save_tensorized_model(
|
if hasattr(self.model_executor, "save_tensorized_model"):
|
||||||
tensorizer_config=tensorizer_config,
|
self.model_executor.save_tensorized_model( # type: ignore[attr-defined]
|
||||||
)
|
tensorizer_config=tensorizer_config,
|
||||||
|
)
|
||||||
|
|
||||||
def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
|
def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
|
||||||
"""Preprocess the request.
|
"""Preprocess the request.
|
||||||
@ -1089,6 +1090,7 @@ class DPEngineCoreProc(EngineCoreProc):
|
|||||||
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
|
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
|
||||||
|
|
||||||
assert dp_size > 1
|
assert dp_size > 1
|
||||||
|
assert local_dp_rank is not None
|
||||||
assert 0 <= local_dp_rank <= dp_rank < dp_size
|
assert 0 <= local_dp_rank <= dp_rank < dp_size
|
||||||
|
|
||||||
if vllm_config.kv_transfer_config is not None:
|
if vllm_config.kv_transfer_config is not None:
|
||||||
@ -1235,7 +1237,8 @@ class DPEngineCoreProc(EngineCoreProc):
|
|||||||
parallel_config.data_parallel_master_port
|
parallel_config.data_parallel_master_port
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model_executor.reinitialize_distributed(reconfig_request)
|
if hasattr(self.model_executor, "reinitialize_distributed"):
|
||||||
|
self.model_executor.reinitialize_distributed(reconfig_request) # type: ignore[attr-defined]
|
||||||
if reconfig_request.new_data_parallel_size > old_dp_size:
|
if reconfig_request.new_data_parallel_size > old_dp_size:
|
||||||
assert self.available_gpu_memory_for_kv_cache > 0
|
assert self.available_gpu_memory_for_kv_cache > 0
|
||||||
# pass available_gpu_memory_for_kv_cache from existing
|
# pass available_gpu_memory_for_kv_cache from existing
|
||||||
|
|||||||
@ -385,7 +385,7 @@ class BackgroundResources:
|
|||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
if in_loop(loop):
|
if loop is not None and in_loop(loop):
|
||||||
close_sockets_and_tasks()
|
close_sockets_and_tasks()
|
||||||
elif loop and not loop.is_closed():
|
elif loop and not loop.is_closed():
|
||||||
loop.call_soon_threadsafe(close_sockets_and_tasks)
|
loop.call_soon_threadsafe(close_sockets_and_tasks)
|
||||||
@ -1044,6 +1044,7 @@ class DPAsyncMPClient(AsyncMPClient):
|
|||||||
return
|
return
|
||||||
|
|
||||||
assert self.stats_update_address is not None
|
assert self.stats_update_address is not None
|
||||||
|
stats_addr: str = self.stats_update_address
|
||||||
assert len(self.engine_ranks_managed) > 0
|
assert len(self.engine_ranks_managed) > 0
|
||||||
# NOTE: running and waiting counts are all global from
|
# NOTE: running and waiting counts are all global from
|
||||||
# the Coordinator include all global EngineCores. This
|
# the Coordinator include all global EngineCores. This
|
||||||
@ -1054,9 +1055,7 @@ class DPAsyncMPClient(AsyncMPClient):
|
|||||||
|
|
||||||
async def run_engine_stats_update_task():
|
async def run_engine_stats_update_task():
|
||||||
with (
|
with (
|
||||||
make_zmq_socket(
|
make_zmq_socket(self.ctx, stats_addr, zmq.XSUB, linger=0) as socket,
|
||||||
self.ctx, self.stats_update_address, zmq.XSUB, linger=0
|
|
||||||
) as socket,
|
|
||||||
make_zmq_socket(
|
make_zmq_socket(
|
||||||
self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=False, linger=0
|
self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=False, linger=0
|
||||||
) as first_req_rcv_socket,
|
) as first_req_rcv_socket,
|
||||||
|
|||||||
@ -69,14 +69,21 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
|
|||||||
# Stop strings
|
# Stop strings
|
||||||
params = request.sampling_params
|
params = request.sampling_params
|
||||||
assert params is not None
|
assert params is not None
|
||||||
self.stop = stop = params.stop
|
stop_list: list[str]
|
||||||
|
if params.stop is None:
|
||||||
|
stop_list = []
|
||||||
|
elif isinstance(params.stop, str):
|
||||||
|
stop_list = [params.stop]
|
||||||
|
else:
|
||||||
|
stop_list = params.stop
|
||||||
|
self.stop = stop_list
|
||||||
self.min_tokens = params.min_tokens
|
self.min_tokens = params.min_tokens
|
||||||
self.include_stop_str_in_output = params.include_stop_str_in_output
|
self.include_stop_str_in_output = params.include_stop_str_in_output
|
||||||
|
|
||||||
# Number of chars to hold back when stop strings are to be excluded
|
# Number of chars to hold back when stop strings are to be excluded
|
||||||
# from streamed output.
|
# from streamed output.
|
||||||
if stop and not self.include_stop_str_in_output:
|
if self.stop and not self.include_stop_str_in_output:
|
||||||
self.stop_buffer_length = max(len(s) for s in stop) - 1
|
self.stop_buffer_length = max(len(s) for s in self.stop) - 1
|
||||||
else:
|
else:
|
||||||
self.stop_buffer_length = 0
|
self.stop_buffer_length = 0
|
||||||
self._last_output_text_offset: int = 0
|
self._last_output_text_offset: int = 0
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
import time
|
import time
|
||||||
from collections.abc import Callable, Mapping
|
from collections.abc import Callable, Mapping
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing_extensions import TypeVar
|
from typing_extensions import TypeVar
|
||||||
@ -112,10 +112,9 @@ class LLMEngine:
|
|||||||
self.output_processor = OutputProcessor(
|
self.output_processor = OutputProcessor(
|
||||||
self.tokenizer, log_stats=self.log_stats
|
self.tokenizer, log_stats=self.log_stats
|
||||||
)
|
)
|
||||||
if self.observability_config.otlp_traces_endpoint is not None:
|
endpoint = getattr(self.observability_config, "otlp_traces_endpoint", None)
|
||||||
tracer = init_tracer(
|
if endpoint is not None:
|
||||||
"vllm.llm_engine", self.observability_config.otlp_traces_endpoint
|
tracer = init_tracer("vllm.llm_engine", endpoint)
|
||||||
)
|
|
||||||
self.output_processor.tracer = tracer
|
self.output_processor.tracer = tracer
|
||||||
|
|
||||||
# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
|
# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
|
||||||
@ -259,7 +258,10 @@ class LLMEngine:
|
|||||||
trace_headers,
|
trace_headers,
|
||||||
priority,
|
priority,
|
||||||
)
|
)
|
||||||
prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt")
|
if isinstance(prompt, str):
|
||||||
|
prompt_text = prompt
|
||||||
|
elif isinstance(prompt, Mapping):
|
||||||
|
prompt_text = cast(str | None, prompt.get("prompt"))
|
||||||
|
|
||||||
n = params.n if isinstance(params, SamplingParams) else 1
|
n = params.n if isinstance(params, SamplingParams) else 1
|
||||||
|
|
||||||
@ -316,7 +318,14 @@ class LLMEngine:
|
|||||||
)
|
)
|
||||||
self.do_log_stats_with_interval()
|
self.do_log_stats_with_interval()
|
||||||
|
|
||||||
return processed_outputs.request_outputs
|
ro = processed_outputs.request_outputs
|
||||||
|
if not ro:
|
||||||
|
return []
|
||||||
|
first = ro[0]
|
||||||
|
if isinstance(first, RequestOutput):
|
||||||
|
return [x for x in ro if isinstance(x, RequestOutput)]
|
||||||
|
else:
|
||||||
|
return [x for x in ro if isinstance(x, PoolingRequestOutput)]
|
||||||
|
|
||||||
def start_profile(self):
|
def start_profile(self):
|
||||||
self.engine_core.profile(True)
|
self.engine_core.profile(True)
|
||||||
|
|||||||
@ -47,7 +47,14 @@ class RequestOutputCollector:
|
|||||||
elif isinstance(self.output, (RequestOutput, PoolingRequestOutput)):
|
elif isinstance(self.output, (RequestOutput, PoolingRequestOutput)):
|
||||||
# This ensures that request outputs with different request indexes
|
# This ensures that request outputs with different request indexes
|
||||||
# (if n > 1) do not override each other.
|
# (if n > 1) do not override each other.
|
||||||
self.output.add(output, aggregate=self.aggregate)
|
if isinstance(self.output, RequestOutput) and isinstance(
|
||||||
|
output, RequestOutput
|
||||||
|
):
|
||||||
|
self.output.add(output, aggregate=self.aggregate)
|
||||||
|
elif isinstance(self.output, PoolingRequestOutput) and isinstance(
|
||||||
|
output, PoolingRequestOutput
|
||||||
|
):
|
||||||
|
self.output = output
|
||||||
|
|
||||||
async def get(self) -> RequestOutput | PoolingRequestOutput:
|
async def get(self) -> RequestOutput | PoolingRequestOutput:
|
||||||
"""Get operation blocks on put event."""
|
"""Get operation blocks on put event."""
|
||||||
@ -407,7 +414,7 @@ class OutputProcessor:
|
|||||||
within the loop below.
|
within the loop below.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
request_outputs: list[RequestOutput] | list[PoolingRequestOutput] = []
|
request_outputs: list[RequestOutput | PoolingRequestOutput] = []
|
||||||
reqs_to_abort: list[str] = []
|
reqs_to_abort: list[str] = []
|
||||||
for engine_core_output in engine_core_outputs:
|
for engine_core_output in engine_core_outputs:
|
||||||
req_id = engine_core_output.request_id
|
req_id = engine_core_output.request_id
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from typing import Optional
|
from typing import Optional, cast
|
||||||
|
|
||||||
from vllm.outputs import CompletionOutput
|
from vllm.outputs import CompletionOutput
|
||||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||||
@ -37,7 +37,7 @@ class ParentRequest:
|
|||||||
|
|
||||||
self.child_requests = set()
|
self.child_requests = set()
|
||||||
self.output_aggregator = (
|
self.output_aggregator = (
|
||||||
[None] * sampling_params.n
|
[cast(CompletionOutput, None)] * sampling_params.n
|
||||||
if (sampling_params.output_kind == RequestOutputKind.FINAL_ONLY)
|
if (sampling_params.output_kind == RequestOutputKind.FINAL_ONLY)
|
||||||
else []
|
else []
|
||||||
)
|
)
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, cast
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
|
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
|
||||||
@ -208,9 +208,9 @@ class Processor:
|
|||||||
enc = prompt.get("encoder_prompt")
|
enc = prompt.get("encoder_prompt")
|
||||||
dec = prompt.get("decoder_prompt")
|
dec = prompt.get("decoder_prompt")
|
||||||
if enc is not None:
|
if enc is not None:
|
||||||
_validate_single_prompt(enc)
|
_validate_single_prompt(cast(dict | str, enc))
|
||||||
if dec is not None:
|
if dec is not None:
|
||||||
_validate_single_prompt(dec)
|
_validate_single_prompt(cast(dict | str, dec))
|
||||||
else:
|
else:
|
||||||
_validate_single_prompt(prompt) # type: ignore[arg-type]
|
_validate_single_prompt(prompt) # type: ignore[arg-type]
|
||||||
|
|
||||||
@ -332,7 +332,7 @@ class Processor:
|
|||||||
if not mm_data:
|
if not mm_data:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
mm_uuids: MultiModalUUIDDict = {}
|
mm_uuids: dict[str, list[str | None] | str] = {}
|
||||||
for modality, data in mm_data.items():
|
for modality, data in mm_data.items():
|
||||||
n = len(data) if isinstance(data, list) else 1
|
n = len(data) if isinstance(data, list) else 1
|
||||||
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
|
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
|
||||||
@ -384,7 +384,9 @@ class Processor:
|
|||||||
# if provided.
|
# if provided.
|
||||||
self._validate_multi_modal_uuids(prompt)
|
self._validate_multi_modal_uuids(prompt)
|
||||||
if isinstance(prompt, dict):
|
if isinstance(prompt, dict):
|
||||||
mm_uuids = prompt.get("multi_modal_uuids")
|
mm_uuids = cast(
|
||||||
|
MultiModalUUIDDict | None, prompt.get("multi_modal_uuids")
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
mm_uuids = None
|
mm_uuids = None
|
||||||
|
|
||||||
@ -410,20 +412,13 @@ class Processor:
|
|||||||
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
|
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
|
||||||
self._validate_model_inputs(encoder_inputs, decoder_inputs)
|
self._validate_model_inputs(encoder_inputs, decoder_inputs)
|
||||||
|
|
||||||
# Mypy does not always properly infer the types of some elements of
|
# Mypy can be conservative for TypedDict unions; normalize access.
|
||||||
# discriminated unions of TypedDicts, because of how it handles
|
if decoder_inputs["type"] == "embeds":
|
||||||
# inheritance of TypedDict. If we explicitly extract the items we want
|
prompt_token_ids = None
|
||||||
# we can avoid type errors from using `dict.get` later in the method.
|
prompt_embeds = decoder_inputs["prompt_embeds"]
|
||||||
prompt_token_ids = (
|
else:
|
||||||
decoder_inputs["prompt_token_ids"]
|
prompt_token_ids = decoder_inputs["prompt_token_ids"]
|
||||||
if decoder_inputs["type"] != "embeds"
|
prompt_embeds = None
|
||||||
else None
|
|
||||||
)
|
|
||||||
prompt_embeds = (
|
|
||||||
decoder_inputs["prompt_embeds"]
|
|
||||||
if decoder_inputs["type"] == "embeds"
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
sampling_params = None
|
sampling_params = None
|
||||||
pooling_params = None
|
pooling_params = None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user