[V0 deprecation] Remove unreachable model_config.supported_tasks (#25642)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
parent eaeca3cd7f
commit 7f570f1caa
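With the V0 fallback gone, entrypoints take the supported-task list from the engine client rather than from ModelConfig.supported_tasks. A minimal sketch of that single remaining code path, assuming an engine_client object that provides the new get_supported_tasks() coroutine (the stub client below is illustrative, not vLLM code):

import asyncio


class _StubEngineClient:
    """Illustrative stand-in for a real EngineClient (e.g. vLLM's V1 async engine)."""

    async def get_supported_tasks(self) -> tuple[str, ...]:
        return ("generate",)


async def log_supported_tasks(engine_client) -> None:
    # Single code path: no VLLM_USE_V1 branch, no model_config.supported_tasks fallback.
    supported_tasks = await engine_client.get_supported_tasks()
    print("Supported_tasks:", supported_tasks)


asyncio.run(log_supported_tasks(_StubEngineClient()))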
@@ -97,7 +97,6 @@ def test_auto_task(model_id, expected_runner_type, expected_convert_type,
 
     assert config.runner_type == expected_runner_type
     assert config.convert_type == expected_convert_type
-    assert expected_task in config.supported_tasks
 
 
 # Can remove once --task option is fully deprecated
@@ -120,7 +119,6 @@ def test_score_task(model_id, expected_runner_type, expected_convert_type,
 
     assert config.runner_type == expected_runner_type
     assert config.convert_type == expected_convert_type
-    assert expected_task in config.supported_tasks
 
 
 # Can remove once --task option is fully deprecated
@@ -137,7 +135,6 @@ def test_transcription_task(model_id, expected_runner_type,
 
     assert config.runner_type == expected_runner_type
     assert config.convert_type == expected_convert_type
-    assert expected_task in config.supported_tasks
 
 
 @pytest.mark.parametrize(
@@ -14,7 +14,6 @@ from pydantic import (ConfigDict, SkipValidation, field_validator,
                       model_validator)
 from pydantic.dataclasses import dataclass
 from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
-from typing_extensions import assert_never
 
 import vllm.envs as envs
 from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
@@ -534,9 +533,6 @@ class ModelConfig:
                 f"You can pass `--convert {convert_option} to adapt "
                 "it into a pooling model.")
 
-        self.supported_tasks = self._get_supported_tasks(
-            architectures, self.runner_type, self.convert_type)
-
         # Note: Initialize these attributes early because transformers fallback
         # may fail to load dynamic modules in child processes
         model_info, arch = registry.inspect_model_cls(architectures, self)
@@ -834,27 +830,6 @@ class ModelConfig:
 
         return convert_type
 
-    def _get_supported_generation_tasks(
-        self,
-        architectures: list[str],
-        convert_type: ConvertType,
-    ) -> list[_ResolvedTask]:
-        registry = self.registry
-
-        if registry.is_transcription_only_model(architectures, self):
-            return ["transcription"]
-
-        # TODO: Use get_supported_generation_tasks once V0 is removed
-        supported_tasks = list[_ResolvedTask]()
-        if (registry.is_text_generation_model(architectures, self)
-                or convert_type in _RUNNER_CONVERTS["generate"]):
-            supported_tasks.append("generate")
-
-        if registry.is_transcription_model(architectures, self):
-            supported_tasks.append("transcription")
-
-        return supported_tasks
-
     def _get_default_pooling_task(
         self,
         architectures: list[str],
@@ -872,42 +847,6 @@ class ModelConfig:
 
         return "embed"
 
-    def _get_supported_pooling_tasks(
-        self,
-        architectures: list[str],
-        convert_type: ConvertType,
-    ) -> list[_ResolvedTask]:
-        registry = self.registry
-
-        # TODO: Use get_supported_pooling_tasks once V0 is removed
-        supported_tasks = list[_ResolvedTask]()
-        if (registry.is_pooling_model(architectures, self)
-                or convert_type in _RUNNER_CONVERTS["pooling"]):
-            supported_tasks.append("encode")
-
-            extra_task = (self._get_default_pooling_task(architectures)
-                          if convert_type == "none" else convert_type)
-            supported_tasks.append(extra_task)
-
-        return supported_tasks
-
-    def _get_supported_tasks(
-        self,
-        architectures: list[str],
-        runner_type: RunnerType,
-        convert_type: ConvertType,
-    ) -> list[_ResolvedTask]:
-        if runner_type == "generate":
-            return self._get_supported_generation_tasks(
-                architectures, convert_type)
-        if runner_type == "pooling":
-            return self._get_supported_pooling_tasks(architectures,
-                                                     convert_type)
-        if runner_type == "draft":
-            return ["draft"]
-
-        assert_never(runner_type)
-
     def _parse_quant_hf_config(self, hf_config: PretrainedConfig):
         quant_cfg = getattr(hf_config, "quantization_config", None)
         if quant_cfg is None:
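For context, the deleted _get_supported_tasks dispatched on runner_type and ended in assert_never(runner_type), which is why the typing_extensions.assert_never import is also dropped above. A generic sketch of that exhaustiveness-check pattern, using simplified stand-in types rather than vLLM's RunnerType/_ResolvedTask:

from typing import Literal

from typing_extensions import assert_never

# Simplified stand-in for vLLM's RunnerType.
RunnerType = Literal["generate", "pooling", "draft"]


def resolve_tasks(runner_type: RunnerType) -> list[str]:
    if runner_type == "generate":
        return ["generate"]
    if runner_type == "pooling":
        return ["encode"]
    if runner_type == "draft":
        return ["draft"]
    # A type checker accepts this call only when every RunnerType member is
    # handled above; at runtime it raises if it is ever reached.
    assert_never(runner_type)


print(resolve_tasks("pooling"))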
@@ -16,6 +16,7 @@ from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
 from vllm.plugins.io_processors.interface import IOProcessor
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import BeamSearchParams, SamplingParams
+from vllm.tasks import SupportedTask
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import Device, collect_from_async_generator, random_uuid
@@ -326,3 +327,7 @@ class EngineClient(ABC):
                              kwargs: Optional[dict] = None):
         """Perform a collective RPC call to the given path."""
         raise NotImplementedError
+
+    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        """Get supported tasks"""
+        raise NotImplementedError
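The base-class default raises NotImplementedError, so concrete engine clients are expected to override get_supported_tasks and return an immutable tuple of task names. A hedged sketch of what an override might look like; the class and its internals are hypothetical, not vLLM's actual client:

class ToyEngineClient:
    """Hypothetical client illustrating the new coroutine's contract."""

    def __init__(self, tasks: tuple[str, ...] = ("generate",)):
        self._tasks = tasks

    async def get_supported_tasks(self) -> tuple[str, ...]:
        # A real client would query the engine core; here the value is fixed.
        return self._tasks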
@@ -1609,11 +1609,7 @@ async def init_app_state(
     state.vllm_config = vllm_config
     model_config = vllm_config.model_config
 
-    if envs.VLLM_USE_V1:
-        supported_tasks = await engine_client \
-            .get_supported_tasks()  # type: ignore
-    else:
-        supported_tasks = model_config.supported_tasks
+    supported_tasks = await engine_client.get_supported_tasks()
 
     logger.info("Supported_tasks: %s", supported_tasks)
 
@@ -14,7 +14,6 @@ import torch
 from prometheus_client import start_http_server
 from tqdm import tqdm
 
-import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
 from vllm.engine.protocol import EngineClient
@@ -334,12 +333,7 @@ async def run_batch(
 
     model_config = vllm_config.model_config
 
-    if envs.VLLM_USE_V1:
-        supported_tasks = await engine_client \
-            .get_supported_tasks()  # type: ignore
-    else:
-        supported_tasks = model_config.supported_tasks
-
+    supported_tasks = await engine_client.get_supported_tasks()
     logger.info("Supported_tasks: %s", supported_tasks)
 
     # Create the openai serving objects.
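Both call sites then use the returned tuple to decide which serving objects to construct. A rough sketch of that kind of gating; the handler names and factory below are placeholders, not vLLM's actual OpenAI serving classes:

def build_handlers(supported_tasks: tuple[str, ...]) -> dict[str, object]:
    handlers: dict[str, object] = {}
    if "generate" in supported_tasks:
        # e.g. a chat/completions serving object in the real server
        handlers["completions"] = object()
    if "encode" in supported_tasks:
        # e.g. a pooling/embeddings serving object in the real server
        handlers["embeddings"] = object()
    return handlers


print(build_handlers(("generate",)))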