diff --git a/tests/test_config.py b/tests/test_config.py
index 9e2bfb9e1b0ec..90d0c78c451f6 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -97,7 +97,6 @@ def test_auto_task(model_id, expected_runner_type, expected_convert_type,
 
     assert config.runner_type == expected_runner_type
     assert config.convert_type == expected_convert_type
-    assert expected_task in config.supported_tasks
 
 
 # Can remove once --task option is fully deprecated
@@ -120,7 +119,6 @@ def test_score_task(model_id, expected_runner_type, expected_convert_type,
 
     assert config.runner_type == expected_runner_type
     assert config.convert_type == expected_convert_type
-    assert expected_task in config.supported_tasks
 
 
 # Can remove once --task option is fully deprecated
@@ -137,7 +135,6 @@ def test_transcription_task(model_id, expected_runner_type,
 
     assert config.runner_type == expected_runner_type
     assert config.convert_type == expected_convert_type
-    assert expected_task in config.supported_tasks
 
 
 @pytest.mark.parametrize(
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 1e0e4d8b3551e..0ded70388b8ac 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -14,7 +14,6 @@ from pydantic import (ConfigDict, SkipValidation, field_validator,
                       model_validator)
 from pydantic.dataclasses import dataclass
 from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
-from typing_extensions import assert_never
 
 import vllm.envs as envs
 from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
@@ -534,9 +533,6 @@ class ModelConfig:
                     f"You can pass `--convert {convert_option} to adapt "
                     "it into a pooling model.")
 
-        self.supported_tasks = self._get_supported_tasks(
-            architectures, self.runner_type, self.convert_type)
-
         # Note: Initialize these attributes early because transformers fallback
         # may fail to load dynamic modules in child processes
         model_info, arch = registry.inspect_model_cls(architectures, self)
@@ -834,27 +830,6 @@ class ModelConfig:
 
         return convert_type
 
-    def _get_supported_generation_tasks(
-        self,
-        architectures: list[str],
-        convert_type: ConvertType,
-    ) -> list[_ResolvedTask]:
-        registry = self.registry
-
-        if registry.is_transcription_only_model(architectures, self):
-            return ["transcription"]
-
-        # TODO: Use get_supported_generation_tasks once V0 is removed
-        supported_tasks = list[_ResolvedTask]()
-        if (registry.is_text_generation_model(architectures, self)
-                or convert_type in _RUNNER_CONVERTS["generate"]):
-            supported_tasks.append("generate")
-
-        if registry.is_transcription_model(architectures, self):
-            supported_tasks.append("transcription")
-
-        return supported_tasks
-
     def _get_default_pooling_task(
         self,
         architectures: list[str],
@@ -872,42 +847,6 @@ class ModelConfig:
 
         return "embed"
 
-    def _get_supported_pooling_tasks(
-        self,
-        architectures: list[str],
-        convert_type: ConvertType,
-    ) -> list[_ResolvedTask]:
-        registry = self.registry
-
-        # TODO: Use get_supported_pooling_tasks once V0 is removed
-        supported_tasks = list[_ResolvedTask]()
-        if (registry.is_pooling_model(architectures, self)
-                or convert_type in _RUNNER_CONVERTS["pooling"]):
-            supported_tasks.append("encode")
-
-        extra_task = (self._get_default_pooling_task(architectures)
-                      if convert_type == "none" else convert_type)
-        supported_tasks.append(extra_task)
-
-        return supported_tasks
-
-    def _get_supported_tasks(
-        self,
-        architectures: list[str],
-        runner_type: RunnerType,
-        convert_type: ConvertType,
-    ) -> list[_ResolvedTask]:
-        if runner_type == "generate":
-            return self._get_supported_generation_tasks(
-                architectures, convert_type)
-        if runner_type == "pooling":
-            return self._get_supported_pooling_tasks(architectures,
-                                                     convert_type)
-        if runner_type == "draft":
-            return ["draft"]
-
-        assert_never(runner_type)
-
     def _parse_quant_hf_config(self, hf_config: PretrainedConfig):
         quant_cfg = getattr(hf_config, "quantization_config", None)
         if quant_cfg is None:
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index e828ac04364ff..9aea74d0c8f3c 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -16,6 +16,7 @@ from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
 from vllm.plugins.io_processors.interface import IOProcessor
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import BeamSearchParams, SamplingParams
+from vllm.tasks import SupportedTask
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import Device, collect_from_async_generator, random_uuid
 
@@ -326,3 +327,7 @@ class EngineClient(ABC):
                        kwargs: Optional[dict] = None):
         """Perform a collective RPC call to the given path."""
         raise NotImplementedError
+
+    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        """Get supported tasks"""
+        raise NotImplementedError
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index b8ba7e81ef5f6..97cbda63bf426 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -1609,11 +1609,7 @@ async def init_app_state(
     state.vllm_config = vllm_config
     model_config = vllm_config.model_config
 
-    if envs.VLLM_USE_V1:
-        supported_tasks = await engine_client \
-            .get_supported_tasks()  # type: ignore
-    else:
-        supported_tasks = model_config.supported_tasks
+    supported_tasks = await engine_client.get_supported_tasks()
 
     logger.info("Supported_tasks: %s", supported_tasks)
 
diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py
index fa813550e520c..2568c21c4abe9 100644
--- a/vllm/entrypoints/openai/run_batch.py
+++ b/vllm/entrypoints/openai/run_batch.py
@@ -14,7 +14,6 @@ import torch
 from prometheus_client import start_http_server
 from tqdm import tqdm
 
-import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
 from vllm.engine.protocol import EngineClient
@@ -334,12 +333,7 @@ async def run_batch(
 
     model_config = vllm_config.model_config
 
-    if envs.VLLM_USE_V1:
-        supported_tasks = await engine_client \
-            .get_supported_tasks()  # type: ignore
-    else:
-        supported_tasks = model_config.supported_tasks
-
+    supported_tasks = await engine_client.get_supported_tasks()
     logger.info("Supported_tasks: %s", supported_tasks)
 
     # Create the openai serving objects.