[V0 deprecation] Remove unreachable model_config.supported_tasks (#25642)

Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

parent 3d940e2c3f
commit f3d9099b44
@@ -97,7 +97,6 @@ def test_auto_task(model_id, expected_runner_type, expected_convert_type,
 
     assert config.runner_type == expected_runner_type
     assert config.convert_type == expected_convert_type
-    assert expected_task in config.supported_tasks
 
 
 # Can remove once --task option is fully deprecated
@@ -120,7 +119,6 @@ def test_score_task(model_id, expected_runner_type, expected_convert_type,
 
     assert config.runner_type == expected_runner_type
     assert config.convert_type == expected_convert_type
-    assert expected_task in config.supported_tasks
 
 
 # Can remove once --task option is fully deprecated
@@ -137,7 +135,6 @@ def test_transcription_task(model_id, expected_runner_type,
 
     assert config.runner_type == expected_runner_type
     assert config.convert_type == expected_convert_type
-    assert expected_task in config.supported_tasks
 
 
 @pytest.mark.parametrize(
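The three test hunks above make the same edit: with `ModelConfig.supported_tasks` gone, the tests only assert runner and convert resolution. A minimal sketch of the surviving check, assuming `ModelConfig` can be built directly from a model id as in these tests; the model id and expected values below are illustrative placeholders, not entries from the real parametrize tables:

```python
from vllm.config import ModelConfig

# Placeholder values; the real tests parametrize model_id,
# expected_runner_type, and expected_convert_type.
config = ModelConfig("Qwen/Qwen3-0.6B")
assert config.runner_type == "generate"
assert config.convert_type == "none"
```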

@@ -14,7 +14,6 @@ from pydantic import (ConfigDict, SkipValidation, field_validator,
                       model_validator)
 from pydantic.dataclasses import dataclass
 from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
-from typing_extensions import assert_never
 
 import vllm.envs as envs
 from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
@@ -534,9 +533,6 @@ class ModelConfig:
                 f"You can pass `--convert {convert_option} to adapt "
                 "it into a pooling model.")
 
-        self.supported_tasks = self._get_supported_tasks(
-            architectures, self.runner_type, self.convert_type)
-
         # Note: Initialize these attributes early because transformers fallback
         # may fail to load dynamic modules in child processes
         model_info, arch = registry.inspect_model_cls(architectures, self)
@@ -834,27 +830,6 @@ class ModelConfig:
 
         return convert_type
 
-    def _get_supported_generation_tasks(
-        self,
-        architectures: list[str],
-        convert_type: ConvertType,
-    ) -> list[_ResolvedTask]:
-        registry = self.registry
-
-        if registry.is_transcription_only_model(architectures, self):
-            return ["transcription"]
-
-        # TODO: Use get_supported_generation_tasks once V0 is removed
-        supported_tasks = list[_ResolvedTask]()
-        if (registry.is_text_generation_model(architectures, self)
-                or convert_type in _RUNNER_CONVERTS["generate"]):
-            supported_tasks.append("generate")
-
-        if registry.is_transcription_model(architectures, self):
-            supported_tasks.append("transcription")
-
-        return supported_tasks
-
     def _get_default_pooling_task(
         self,
         architectures: list[str],
@@ -872,42 +847,6 @@ class ModelConfig:
 
         return "embed"
 
-    def _get_supported_pooling_tasks(
-        self,
-        architectures: list[str],
-        convert_type: ConvertType,
-    ) -> list[_ResolvedTask]:
-        registry = self.registry
-
-        # TODO: Use get_supported_pooling_tasks once V0 is removed
-        supported_tasks = list[_ResolvedTask]()
-        if (registry.is_pooling_model(architectures, self)
-                or convert_type in _RUNNER_CONVERTS["pooling"]):
-            supported_tasks.append("encode")
-
-            extra_task = (self._get_default_pooling_task(architectures)
-                          if convert_type == "none" else convert_type)
-            supported_tasks.append(extra_task)
-
-        return supported_tasks
-
-    def _get_supported_tasks(
-        self,
-        architectures: list[str],
-        runner_type: RunnerType,
-        convert_type: ConvertType,
-    ) -> list[_ResolvedTask]:
-        if runner_type == "generate":
-            return self._get_supported_generation_tasks(
-                architectures, convert_type)
-        if runner_type == "pooling":
-            return self._get_supported_pooling_tasks(architectures,
-                                                     convert_type)
-        if runner_type == "draft":
-            return ["draft"]
-
-        assert_never(runner_type)
-
     def _parse_quant_hf_config(self, hf_config: PretrainedConfig):
         quant_cfg = getattr(hf_config, "quantization_config", None)
         if quant_cfg is None:
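For reference, the deleted helpers encoded V0's task-resolution mapping. The sketch below condenses that logic into one self-contained function; the boolean flags stand in for the registry queries and `_RUNNER_CONVERTS` membership tests the real methods performed, so this documents the removed behavior and is not live vLLM code:

```python
def resolve_supported_tasks(
    runner_type: str,             # "generate" | "pooling" | "draft"
    convert_type: str,            # "none" or an explicit --convert target
    is_transcription_only: bool,  # registry.is_transcription_only_model(...)
    generates_text: bool,         # text-generation model or a generate convert
    transcribes: bool,            # registry.is_transcription_model(...)
    pools: bool,                  # pooling model or a pooling convert
    default_pooling_task: str,    # _get_default_pooling_task(...), e.g. "embed"
) -> list[str]:
    if runner_type == "generate":
        if is_transcription_only:
            return ["transcription"]
        tasks: list[str] = []
        if generates_text:
            tasks.append("generate")
        if transcribes:
            tasks.append("transcription")
        return tasks
    if runner_type == "pooling":
        tasks = []
        if pools:
            tasks.append("encode")
            # Either the architecture's default pooling task or the
            # explicit convert target.
            tasks.append(default_pooling_task
                         if convert_type == "none" else convert_type)
        return tasks
    if runner_type == "draft":
        return ["draft"]
    raise AssertionError(f"unexpected runner_type: {runner_type}")
```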

@@ -16,6 +16,7 @@ from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
 from vllm.plugins.io_processors.interface import IOProcessor
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import BeamSearchParams, SamplingParams
+from vllm.tasks import SupportedTask
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import Device, collect_from_async_generator, random_uuid
 
@@ -326,3 +327,7 @@ class EngineClient(ABC):
                        kwargs: Optional[dict] = None):
         """Perform a collective RPC call to the given path."""
         raise NotImplementedError
+
+    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        """Get supported tasks"""
+        raise NotImplementedError
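This new abstract coroutine on `EngineClient` is what the entrypoint changes below call. A minimal sketch of a conforming client, purely illustrative and not vLLM's actual `AsyncLLM` implementation:

```python
import asyncio

from vllm.tasks import SupportedTask


class ToyEngineClient:
    """Illustrative stand-in for a concrete EngineClient subclass."""

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        # A real engine derives this from the loaded model; the sketch
        # hard-codes a single task.
        return ("generate",)


async def main() -> None:
    supported_tasks = await ToyEngineClient().get_supported_tasks()
    print(supported_tasks)  # e.g. ('generate',)


asyncio.run(main())
```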

@@ -1609,11 +1609,7 @@ async def init_app_state(
     state.vllm_config = vllm_config
     model_config = vllm_config.model_config
 
-    if envs.VLLM_USE_V1:
-        supported_tasks = await engine_client \
-            .get_supported_tasks()  # type: ignore
-    else:
-        supported_tasks = model_config.supported_tasks
+    supported_tasks = await engine_client.get_supported_tasks()
 
     logger.info("Supported_tasks: %s", supported_tasks)
 
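With V0 gone, the `envs.VLLM_USE_V1` branch collapses to a single code path: the server always asks the engine for its supported tasks instead of reading `model_config.supported_tasks`. The `# type: ignore` also disappears, since `get_supported_tasks` is now declared on the `EngineClient` ABC above. The same simplification is applied to the batch entrypoint next.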

@@ -14,7 +14,6 @@ import torch
 from prometheus_client import start_http_server
 from tqdm import tqdm
 
-import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
 from vllm.engine.protocol import EngineClient
@@ -334,12 +333,7 @@ async def run_batch(
 
     model_config = vllm_config.model_config
 
-    if envs.VLLM_USE_V1:
-        supported_tasks = await engine_client \
-            .get_supported_tasks()  # type: ignore
-    else:
-        supported_tasks = model_config.supported_tasks
-
+    supported_tasks = await engine_client.get_supported_tasks()
     logger.info("Supported_tasks: %s", supported_tasks)
 
     # Create the openai serving objects.