[V1] Get supported tasks from model runner instead of model config (#21585)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-07-25 20:36:45 +08:00 committed by GitHub
parent 5c3f2628d5
commit 46d81d6951
19 changed files with 200 additions and 54 deletions

View File

@@ -14,6 +14,7 @@ from pydantic import ValidationError
from tqdm.auto import tqdm
from typing_extensions import TypeVar, deprecated

+import vllm.envs as envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence,
                              create_sort_beams_key_function)
@@ -44,9 +45,10 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
-from vllm.pooling_params import PoolingParams, PoolingTask
+from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
+from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                               get_cached_tokenizer)
from vllm.usage.usage_lib import UsageContext
@@ -277,6 +279,16 @@ class LLM:
        self.request_counter = Counter()
        self.default_sampling_params: Union[dict[str, Any], None] = None

+        if envs.VLLM_USE_V1:
+            supported_tasks = self.llm_engine \
+                .get_supported_tasks()  # type: ignore
+        else:
+            supported_tasks = self.llm_engine.model_config.supported_tasks
+
+        logger.info("Supported_tasks: %s", supported_tasks)
+        self.supported_tasks = supported_tasks
+
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
@@ -1170,8 +1182,7 @@ class LLM:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.
        """
-        model_config = self.llm_engine.model_config
-        if "embed" not in model_config.supported_tasks:
+        if "embed" not in self.supported_tasks:
            raise ValueError("Embedding API is not supported by this model. "
                             "Please set `--task embed`.")
@@ -1215,8 +1226,7 @@ class LLM:
            A list of `ClassificationRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.
        """
-        model_config = self.llm_engine.model_config
-        if "classify" not in model_config.supported_tasks:
+        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Please set `--task classify`.")
@@ -1397,8 +1407,8 @@ class LLM:
            raise ValueError(" ".join(messages))

-        if all(t not in model_config.supported_tasks
-               for t in ("embed", "classify")):
+        supported_tasks = self.supported_tasks
+        if all(t not in supported_tasks for t in ("embed", "classify")):
            raise ValueError("Score API is not supported by this model. "
                             "Please set `--task embed` or `--task classify`.")

View File

@@ -1586,6 +1586,14 @@ async def init_app_state(
    state.vllm_config = vllm_config
    model_config = vllm_config.model_config

+    if envs.VLLM_USE_V1:
+        supported_tasks = await engine_client \
+            .get_supported_tasks()  # type: ignore
+    else:
+        supported_tasks = model_config.supported_tasks
+
+    logger.info("Supported_tasks: %s", supported_tasks)
+
    resolved_chat_template = load_chat_template(args.chat_template)
    if resolved_chat_template is not None:
        # Get the tokenizer to check official template
@@ -1647,7 +1655,7 @@ async def init_app_state(
        reasoning_parser=args.reasoning_parser,
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
        enable_force_include_usage=args.enable_force_include_usage,
-    ) if "generate" in model_config.supported_tasks else None
+    ) if "generate" in supported_tasks else None
    state.openai_serving_chat = OpenAIServingChat(
        engine_client,
        model_config,
@@ -1664,7 +1672,7 @@ async def init_app_state(
        reasoning_parser=args.reasoning_parser,
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
        enable_force_include_usage=args.enable_force_include_usage,
-    ) if "generate" in model_config.supported_tasks else None
+    ) if "generate" in supported_tasks else None
    state.openai_serving_completion = OpenAIServingCompletion(
        engine_client,
        model_config,
@@ -1673,7 +1681,7 @@ async def init_app_state(
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
        enable_force_include_usage=args.enable_force_include_usage,
-    ) if "generate" in model_config.supported_tasks else None
+    ) if "generate" in supported_tasks else None
    state.openai_serving_pooling = OpenAIServingPooling(
        engine_client,
        model_config,
@@ -1681,7 +1689,7 @@ async def init_app_state(
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
-    ) if "encode" in model_config.supported_tasks else None
+    ) if "encode" in supported_tasks else None
    state.openai_serving_embedding = OpenAIServingEmbedding(
        engine_client,
        model_config,
@@ -1689,24 +1697,22 @@ async def init_app_state(
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
-    ) if "embed" in model_config.supported_tasks else None
+    ) if "embed" in supported_tasks else None
    state.openai_serving_classification = ServingClassification(
        engine_client,
        model_config,
        state.openai_serving_models,
        request_logger=request_logger,
-    ) if "classify" in model_config.supported_tasks else None
+    ) if "classify" in supported_tasks else None

-    enable_serving_reranking = ("classify" in model_config.supported_tasks
-                                and getattr(model_config.hf_config,
-                                            "num_labels", 0) == 1)
+    enable_serving_reranking = ("classify" in supported_tasks and getattr(
+        model_config.hf_config, "num_labels", 0) == 1)
    state.openai_serving_scores = ServingScores(
        engine_client,
        model_config,
        state.openai_serving_models,
        request_logger=request_logger,
-    ) if ("embed" in model_config.supported_tasks
-          or enable_serving_reranking) else None
+    ) if ("embed" in supported_tasks or enable_serving_reranking) else None

    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
@@ -1721,13 +1727,13 @@ async def init_app_state(
        model_config,
        state.openai_serving_models,
        request_logger=request_logger,
-    ) if "transcription" in model_config.supported_tasks else None
+    ) if "transcription" in supported_tasks else None
    state.openai_serving_translation = OpenAIServingTranslation(
        engine_client,
        model_config,
        state.openai_serving_models,
        request_logger=request_logger,
-    ) if "transcription" in model_config.supported_tasks else None
+    ) if "transcription" in supported_tasks else None
    state.task = model_config.task

    state.enable_server_load_tracking = args.enable_server_load_tracking
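Note: every serving object above follows the same construction pattern: build it only when its task is reported as supported, otherwise leave the slot as None. A compact, self-contained illustration of that gating (the factory names are placeholders, not vLLM APIs):

# Placeholder factories standing in for OpenAIServingChat, OpenAIServingEmbedding, etc.
def maybe_build(task: str, supported_tasks: tuple[str, ...], factory):
    # Construct the serving object only if its task is supported.
    return factory() if task in supported_tasks else None

supported_tasks = ("embed", "classify")
serving_chat = maybe_build("generate", supported_tasks, lambda: "chat-server")
serving_embedding = maybe_build("embed", supported_tasks, lambda: "embed-server")
print(serving_chat, serving_embedding)  # None embed-server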

View File

@@ -14,6 +14,7 @@ import torch
from prometheus_client import start_http_server
from tqdm import tqdm

+import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.engine.protocol import EngineClient
@@ -335,6 +336,14 @@ async def run_batch(
    model_config = vllm_config.model_config

+    if envs.VLLM_USE_V1:
+        supported_tasks = await engine_client \
+            .get_supported_tasks()  # type: ignore
+    else:
+        supported_tasks = model_config.supported_tasks
+
+    logger.info("Supported_tasks: %s", supported_tasks)
+
    # Create the openai serving objects.
    openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
@@ -351,7 +360,7 @@ async def run_batch(
        chat_template=None,
        chat_template_content_format="auto",
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
-    ) if "generate" in model_config.supported_tasks else None
+    ) if "generate" in supported_tasks else None
    openai_serving_embedding = OpenAIServingEmbedding(
        engine_client,
        model_config,
@@ -359,19 +368,17 @@ async def run_batch(
        request_logger=request_logger,
        chat_template=None,
        chat_template_content_format="auto",
-    ) if "embed" in model_config.supported_tasks else None
+    ) if "embed" in supported_tasks else None

-    enable_serving_reranking = ("classify" in model_config.supported_tasks
-                                and getattr(model_config.hf_config,
-                                            "num_labels", 0) == 1)
+    enable_serving_reranking = ("classify" in supported_tasks and getattr(
+        model_config.hf_config, "num_labels", 0) == 1)
    openai_serving_scores = ServingScores(
        engine_client,
        model_config,
        openai_serving_models,
        request_logger=request_logger,
-    ) if ("embed" in model_config.supported_tasks
-          or enable_serving_reranking) else None
+    ) if ("embed" in supported_tasks or enable_serving_reranking) else None

    tracker = BatchProgressTracker()
    logger.info("Reading batch from %s...", args.input_file)

View File

@@ -16,8 +16,8 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.pooling_params import PoolingTask
from vllm.sequence import ExecuteModelRequest, PoolerOutput
+from vllm.tasks import SupportedTask
from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase
@@ -136,9 +136,9 @@ class ExecutorBase(ABC):
        return self.collective_rpc(rpc_func)

    @cached_property  # Avoid unnecessary RPC calls
-    def supported_pooling_tasks(self) -> tuple[PoolingTask, ...]:
-        output = self.collective_rpc("get_supported_pooling_tasks")
-        return tuple({task for tasks in output for task in tasks})
+    def supported_tasks(self) -> tuple[SupportedTask, ...]:
+        output = self.collective_rpc("get_supported_tasks")
+        return output[0]

    def execute_model(
        self, execute_model_req: ExecuteModelRequest
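Note: the return value changes shape here. The old `supported_pooling_tasks` unioned the per-worker results, while the new `supported_tasks` takes `output[0]`, relying on every worker loading the same model and therefore reporting an identical tuple. A rough sketch of the difference in plain Python (not vLLM's RPC layer):

# Hypothetical per-worker replies from collective_rpc; identical in practice
# because all workers run the same model.
per_worker = [("generate", "transcription"), ("generate", "transcription")]

old_style = tuple({task for tasks in per_worker for task in tasks})  # set union
new_style = per_worker[0]                                            # first reply

print(sorted(old_style) == sorted(new_style))  # True whenever workers agree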

View File

@@ -16,8 +16,9 @@ from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.pooling_metadata import (  # noqa: E501
    PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
-from vllm.pooling_params import PoolingParams, PoolingTask
+from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
+from vllm.tasks import PoolingTask
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

View File

@@ -26,8 +26,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors
+from vllm.tasks import PoolingTask
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix

View File

@@ -16,8 +16,8 @@ from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
                                               get_prompt_token_ids)
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.pooling_params import PoolingTask
from vllm.sequence import PoolerOutput
+from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

from .interfaces import SupportsV0Only

View File

@@ -23,8 +23,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors
+from vllm.tasks import PoolingTask
from .interfaces import SupportsCrossEncoding, SupportsV0Only
from .utils import WeightsMapper, maybe_prefix

View File

@@ -1,17 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import TYPE_CHECKING, Literal, Optional
+from typing import TYPE_CHECKING, Optional

import msgspec

from vllm.sampling_params import RequestOutputKind
+from vllm.tasks import PoolingTask

if TYPE_CHECKING:
    from vllm.config import ModelConfig

-PoolingTask = Literal["encode", "embed", "classify", "score"]
-

class PoolingParams(
    msgspec.Struct,

vllm/tasks.py (new file, 11 lines)
View File

@@ -0,0 +1,11 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Literal, get_args
+
+GenerationTask = Literal["generate", "transcription"]
+GENERATION_TASKS = get_args(GenerationTask)
+
+PoolingTask = Literal["encode", "embed", "classify", "score"]
+POOLING_TASKS = get_args(PoolingTask)
+
+SupportedTask = Literal[GenerationTask, PoolingTask]
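Since typing.Literal flattens nested literals, SupportedTask is simply the union of the six task names, and get_args recovers them as plain tuples. A quick check of the behavior this module relies on:

from typing import Literal, get_args

GenerationTask = Literal["generate", "transcription"]
PoolingTask = Literal["encode", "embed", "classify", "score"]
SupportedTask = Literal[GenerationTask, PoolingTask]

print(get_args(GenerationTask))  # ('generate', 'transcription')
print(get_args(SupportedTask))   # all six task names, flattened into one tuple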

View File

@@ -21,6 +21,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
+from vllm.tasks import SupportedTask
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -211,6 +212,9 @@ class AsyncLLM(EngineClient):
        if handler := getattr(self, "output_handler", None):
            handler.cancel()

+    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return await self.engine_core.get_supported_tasks_async()
+
    async def add_request(
        self,
        request_id: str,

View File

@@ -23,6 +23,7 @@ from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
+from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
from vllm.utils import (bind_process_name, make_zmq_socket,
@@ -195,11 +196,17 @@ class EngineCore:
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.model_executor.supported_tasks
+
    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler."""
        if pooling_params := request.pooling_params:
-            supported_pooling_tasks = (
-                self.model_executor.supported_pooling_tasks)
+            supported_pooling_tasks = [
+                task for task in self.get_supported_tasks()
+                if task in POOLING_TASKS
+            ]
            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                                 f"Supported tasks: {supported_pooling_tasks}")

View File

@@ -21,6 +21,7 @@ import zmq.asyncio
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
+from vllm.tasks import SupportedTask
from vllm.utils import get_open_port, get_open_zmq_inproc_path, make_zmq_socket
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                            EngineCoreRequestType,
@@ -104,6 +105,9 @@ class EngineCoreClient(ABC):
    def get_output(self) -> EngineCoreOutputs:
        raise NotImplementedError

+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        raise NotImplementedError
+
    def add_request(self, request: EngineCoreRequest) -> None:
        raise NotImplementedError
@@ -170,6 +174,9 @@ class EngineCoreClient(ABC):
    async def get_output_async(self) -> EngineCoreOutputs:
        raise NotImplementedError

+    async def get_supported_tasks_async(self) -> tuple[SupportedTask, ...]:
+        raise NotImplementedError
+
    async def add_request_async(self, request: EngineCoreRequest) -> None:
        raise NotImplementedError
@@ -238,6 +245,9 @@ class InprocClient(EngineCoreClient):
        outputs, _ = self.engine_core.step()
        return outputs.get(0) or EngineCoreOutputs()

+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.engine_core.get_supported_tasks()
+
    def add_request(self, request: EngineCoreRequest) -> None:
        self.engine_core.add_request(request)
@@ -608,6 +618,9 @@ class SyncMPClient(MPClient):
        return future.result()

+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.call_utility("get_supported_tasks")
+
    def add_request(self, request: EngineCoreRequest) -> None:
        if self.is_dp:
            self.engines_running = True
@@ -802,6 +815,9 @@ class AsyncMPClient(MPClient):
        self._ensure_output_queue_task()
        return await future

+    async def get_supported_tasks_async(self) -> tuple[SupportedTask, ...]:
+        return await self.call_utility_async("get_supported_tasks")
+
    async def add_request_async(self, request: EngineCoreRequest) -> None:
        request.client_index = self.client_index
        await self._send_input(EngineCoreRequestType.ADD, request)
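Note: both MP clients reuse the existing utility-call path, so no new request type is needed; the method name is shipped to the engine-core process and dispatched by name on EngineCore. A simplified, in-process sketch of that dispatch (the real clients serialize the call over ZMQ):

# Assumed simplification, not the actual MPClient plumbing.
class StubEngineCore:
    def get_supported_tasks(self) -> tuple[str, ...]:
        return ("generate", "transcription")

def call_utility(core: StubEngineCore, method: str, *args):
    # Dispatch a utility method by name, mirroring the RPC behavior in-process.
    return getattr(core, method)(*args)

print(call_utility(StubEngineCore(), "get_supported_tasks"))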

View File

@@ -18,6 +18,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
+from vllm.tasks import SupportedTask
from vllm.transformers_utils.tokenizer_group import (
    TokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
@@ -176,6 +177,9 @@ class LLMEngine:
    def validate_outputs(cls, outputs, output_type):
        return outputs

+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.engine_core.get_supported_tasks()
+
    def abort_request(self, request_ids: list[str]) -> None:
        """Remove request_ids from EngineCore and Detokenizer."""

View File

@@ -30,15 +30,17 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
-from vllm.model_executor.models.interfaces import is_mixture_of_experts
-from vllm.model_executor.models.interfaces_base import (VllmModelForPooling,
-                                                        is_pooling_model)
+from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
+                                                   supports_transcription)
+from vllm.model_executor.models.interfaces_base import (
+    VllmModelForPooling, is_pooling_model, is_text_generation_model)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality
-from vllm.pooling_params import PoolingParams, PoolingTask
+from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors, PoolerOutput
+from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
                        is_pin_memory_available, round_up)
@@ -1153,6 +1155,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
    def get_model(self) -> nn.Module:
        return self.model

+    def get_supported_generation_tasks(self) -> list[GenerationTask]:
+        model = self.get_model()
+        supported_tasks = list[GenerationTask]()
+
+        if is_text_generation_model(model):
+            supported_tasks.append("generate")
+
+        if supports_transcription(model):
+            if model.supports_transcription_only:
+                return ["transcription"]
+
+            supported_tasks.append("transcription")
+
+        return supported_tasks
+
    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
        model = self.get_model()
        if not is_pooling_model(model):
@@ -1160,6 +1177,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

        return list(model.pooler.get_supported_tasks())

+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        tasks = list[SupportedTask]()
+
+        if self.model_config.runner_type == "generate":
+            tasks.extend(self.get_supported_generation_tasks())
+        if self.model_config.runner_type == "pooling":
+            tasks.extend(self.get_supported_pooling_tasks())
+
+        return tuple(tasks)
+
    def apply_grammar_bitmask(
        self,
        scheduler_output: "SchedulerOutput",
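Note: the runner composes its answer from the model's interfaces and the configured runner_type; a model that only supports transcription (for example a Whisper-style ASR model, assuming its supports_transcription_only flag is set) short-circuits to ["transcription"] and never advertises "generate". A toy sketch of that decision with stand-in interface checks:

# Stand-ins; the real runner uses is_text_generation_model(model) and
# supports_transcription(model) on the loaded nn.Module.
class ToyASRModel:
    supports_transcription_only = True  # assumed flag, mirrored from the diff

def generation_tasks(model, is_text_gen: bool, can_transcribe: bool) -> list[str]:
    tasks: list[str] = []
    if is_text_gen:
        tasks.append("generate")
    if can_transcribe:
        if getattr(model, "supports_transcription_only", False):
            return ["transcription"]
        tasks.append("transcription")
    return tasks

print(generation_tasks(ToyASRModel(), is_text_gen=False, can_transcribe=True))
# ['transcription']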

View File

@@ -23,8 +23,8 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
-from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors
+from vllm.tasks import SupportedTask
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
@@ -320,8 +320,8 @@ class Worker(WorkerBase):
    def get_model(self) -> nn.Module:
        return self.model_runner.get_model()

-    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
-        return self.model_runner.get_supported_pooling_tasks()
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.model_runner.get_supported_tasks()

    @torch.inference_mode()
    def execute_model(

View File

@@ -27,13 +27,15 @@ from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.tpu import TPUModelLoader
-from vllm.model_executor.models.interfaces_base import is_pooling_model
+from vllm.model_executor.models.interfaces import supports_transcription
+from vllm.model_executor.models.interfaces_base import (
+    is_pooling_model, is_text_generation_model)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
                                    PlaceholderRange)
from vllm.multimodal.utils import group_mm_inputs_by_modality
-from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors
+from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available,
                        prev_power_of_2)
from vllm.v1.attention.backends.pallas import (TPU_STR_DTYPE_TO_TORCH_DTYPE,
@@ -489,6 +491,21 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
    def get_model(self) -> nn.Module:
        return self.model

+    def get_supported_generation_tasks(self) -> list[GenerationTask]:
+        model = self.get_model()
+        supported_tasks = list[GenerationTask]()
+
+        if is_text_generation_model(model):
+            supported_tasks.append("generate")
+
+        if supports_transcription(model):
+            if model.supports_transcription_only:
+                return ["transcription"]
+
+            supported_tasks.append("transcription")
+
+        return supported_tasks
+
    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
        model = self.get_model()
        if not is_pooling_model(model):
@@ -496,6 +513,16 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

        return list(model.pooler.get_supported_tasks())

+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        tasks = list[SupportedTask]()
+
+        if self.model_config.runner_type == "generate":
+            tasks.extend(self.get_supported_generation_tasks())
+        if self.model_config.runner_type == "pooling":
+            tasks.extend(self.get_supported_pooling_tasks())
+
+        return tuple(tasks)
+
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each

View File

@@ -21,7 +21,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
-from vllm.pooling_params import PoolingTask
+from vllm.tasks import SupportedTask
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
from vllm.v1.core.sched.output import SchedulerOutput
@@ -282,8 +282,8 @@ class TPUWorker:
    def get_model(self) -> nn.Module:
        return self.model_runner.get_model()

-    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
-        return self.model_runner.get_supported_pooling_tasks()
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.model_runner.get_supported_tasks()

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        return self.model_runner.get_kv_cache_spec()

View File

@@ -12,9 +12,11 @@ import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.model_executor.models.interfaces_base import is_pooling_model
-from vllm.pooling_params import PoolingTask
+from vllm.model_executor.models.interfaces import supports_transcription
+from vllm.model_executor.models.interfaces_base import (
+    is_pooling_model, is_text_generation_model)
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
+from vllm.tasks import GenerationTask, PoolingTask, SupportedTask

if TYPE_CHECKING:
    from vllm.attention import AttentionMetadata
@@ -224,6 +226,21 @@ class ModelRunnerBase(ABC, Generic[T]):
    def get_model(self) -> nn.Module:
        raise NotImplementedError

+    def get_supported_generation_tasks(self) -> list[GenerationTask]:
+        model = self.get_model()
+        supported_tasks = list[GenerationTask]()
+
+        if is_text_generation_model(model):
+            supported_tasks.append("generate")
+
+        if supports_transcription(model):
+            if model.supports_transcription_only:
+                return ["transcription"]
+
+            supported_tasks.append("transcription")
+
+        return supported_tasks
+
    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
        model = self.get_model()
        if not is_pooling_model(model):
@@ -231,6 +248,16 @@ class ModelRunnerBase(ABC, Generic[T]):

        return list(model.pooler.get_supported_tasks())

+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        tasks = list[SupportedTask]()
+
+        if self.model_config.runner_type == "generate":
+            tasks.extend(self.get_supported_generation_tasks())
+        if self.model_config.runner_type == "pooling":
+            tasks.extend(self.get_supported_pooling_tasks())
+
+        return tuple(tasks)
+
    def execute_model(
        self,
        model_input: T,