[Core] Support multiple tasks per model (#20771)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Nicolò Lucchesi 2025-07-13 04:40:11 +02:00 committed by GitHub
parent c1acd6d7d4
commit 020f58abcd
8 changed files with 279 additions and 148 deletions

View File

@@ -54,7 +54,7 @@ def test_get_field():
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
("openai/whisper-small", "transcription", "transcription"),
("openai/whisper-small", "generate", "transcription"),
],
)
def test_auto_task(model_id, expected_runner_type, expected_task):
@@ -69,7 +69,11 @@ def test_auto_task(model_id, expected_runner_type, expected_task):
)
assert config.runner_type == expected_runner_type
assert config.task == expected_task
if config.runner_type == "pooling":
assert config.task == expected_task
else:
assert expected_task in config.supported_tasks
@pytest.mark.parametrize(
@@ -98,11 +102,50 @@ def test_score_task(model_id, expected_runner_type, expected_task):
assert config.task == expected_task
@pytest.mark.parametrize(("model_id", "expected_runner_type", "expected_task"),
[
("Qwen/Qwen2.5-1.5B-Instruct", "draft", "auto"),
])
def test_draft_task(model_id, expected_runner_type, expected_task):
config = ModelConfig(
model_id,
runner="draft",
tokenizer=model_id,
seed=0,
dtype="float16",
)
assert config.runner_type == expected_runner_type
assert config.task == expected_task
@pytest.mark.parametrize(
("model_id", "expected_runner_type", "expected_task"),
[
("openai/whisper-small", "generate", "transcription"),
],
)
def test_transcription_task(model_id, expected_runner_type, expected_task):
config = ModelConfig(
model_id,
task="transcription",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
)
assert config.runner_type == expected_runner_type
assert config.task == expected_task
@pytest.mark.parametrize(("model_id", "bad_task"), [
("Qwen/Qwen2.5-Math-RM-72B", "generate"),
("Qwen/Qwen3-0.6B", "transcription"),
])
def test_incorrect_task(model_id, bad_task):
with pytest.raises(ValueError, match=r"does not support the .* task"):
with pytest.raises(ValueError, match=r"does not support task=.*"):
ModelConfig(
model_id,
task=bad_task,

View File

@@ -91,24 +91,19 @@ logger = init_logger(__name__)
ConfigT = TypeVar("ConfigT", bound=ConfigType)
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
"score", "reward", "transcription"]
"score", "reward", "transcription", "draft"]
_ResolvedTask = Literal["generate", "embed", "classify", "reward", "draft",
"transcription"]
_ResolvedTask = Literal["generate", "transcription", "pooling", "embed",
"classify", "reward", "draft"]
RunnerType = Literal["generate", "pooling", "draft", "transcription"]
RunnerOption = Literal["auto", "generate", "pooling", "draft"]
RunnerType = Literal["generate", "pooling", "draft"]
_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
"generate": ["generate"],
"pooling": ["embed", "classify", "reward"],
"draft": ["draft"],
"transcription": ["transcription"],
}
_TASK_RUNNER: dict[_ResolvedTask, RunnerType] = {
task: runner
for runner, tasks in _RUNNER_TASKS.items()
for task in tasks
"generate": ["generate", "transcription"],
"pooling": ["pooling", "embed", "classify", "reward"],
"draft": [],
}
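With tasks now grouped under their runner, the set of runner types a model supports falls straight out of this mapping: a runner is available whenever its task list is non-empty after filtering (mirroring `_get_supported_runner_types` further down). A minimal standalone sketch of that derivation, using the literal values above rather than the vLLM module itself:

# Sketch only: reproduces the runner -> tasks grouping introduced here.
RUNNER_TASKS: dict[str, list[str]] = {
    "generate": ["generate", "transcription"],
    "pooling": ["pooling", "embed", "classify", "reward"],
    "draft": [],
}

def supported_runner_types(supported_tasks: dict[str, list[str]]) -> set[str]:
    # A runner is supported if at least one of its tasks survived filtering.
    return {runner for runner, tasks in supported_tasks.items() if tasks}

# e.g. a text-only generative model after filtering:
print(sorted(supported_runner_types(
    {"generate": ["generate"], "pooling": [], "draft": ["draft"]})))
# -> ['draft', 'generate']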
@@ -234,11 +229,14 @@ class ModelConfig:
"""Name or path of the Hugging Face model to use. It is also used as the
content for `model_name` tag in metrics output when `served_model_name` is
not specified."""
task: Literal[TaskOption, Literal["draft"]] = "auto"
"""The task to use the model for. Each vLLM instance only supports one
task, even if the same model can be used for multiple tasks. When the model
only supports one task, "auto" can be used to select it; otherwise, you
must specify explicitly which task to use."""
runner: RunnerOption = "auto"
"""The type of model runner to use. Each vLLM instance only supports one
model runner, even if the same model can be used for multiple types."""
task: TaskOption = "auto"
"""The task to use the model for. If the model supports more than one
model runner, this is used to select which model runner to run.
Note that the model may support other tasks using the same model runner."""
tokenizer: SkipValidation[str] = None # type: ignore
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model
name or path will be used."""
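The split between `runner` and `task` can be exercised directly with the constructor arguments used by the updated tests; a hedged sketch (model IDs and dtype copied from the test cases, network access to the Hugging Face configs assumed):

from vllm.config import ModelConfig

# Runner selected explicitly, as SpeculativeConfig now does for draft models.
draft_cfg = ModelConfig("Qwen/Qwen2.5-1.5B-Instruct", runner="draft",
                        tokenizer="Qwen/Qwen2.5-1.5B-Instruct",
                        seed=0, dtype="float16")
assert draft_cfg.runner_type == "draft"

# Task selected explicitly; the runner is inferred ("generate" for Whisper).
stt_cfg = ModelConfig("openai/whisper-small", task="transcription",
                      tokenizer="openai/whisper-small",
                      seed=0, dtype="float16")
assert stt_cfg.runner_type == "generate"
assert "transcription" in stt_cfg.supported_tasks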
@@ -553,10 +551,41 @@ class ModelConfig:
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, hf_token=self.hf_token, revision=self.revision)
supported_tasks, task = self._resolve_task(self.task)
self.supported_tasks = supported_tasks
self.task = task
if self.task in ("draft", "generate"):
# For pooling models, self.task is used to indicate the
# user-selected task
if self.task == "score":
if self.registry.is_cross_encoder_model(self.architectures):
self.task = "classify"
else:
self.task = "embed"
elif self.task == "embedding":
msg = ("The 'embedding' task has been renamed to 'embed', please "
"use the new name. The old name will be removed in v1.0.")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
self.task = "embed"
all_supported_tasks = self._get_supported_tasks(self.task)
logger.debug("Tasks supported by runner type: %s", all_supported_tasks)
supported_runner_types = self._get_supported_runner_types(
all_supported_tasks)
runner_type = self._resolve_runner(self.runner, self.task,
supported_runner_types,
all_supported_tasks)
logger.debug("Selected runner type: %s", runner_type)
# For pooling models, self.task is used to indicate the
# user-selected task
if runner_type == "pooling" and self.task == "auto":
selected_task = all_supported_tasks[runner_type][-1]
assert selected_task != "pooling"
self.task = selected_task
self.supported_runner_types = supported_runner_types
self.runner_type = runner_type
self.supported_tasks = all_supported_tasks[runner_type]
if self.runner_type in ("draft",
"generate") and self.task != "transcription":
self.truncation_side = "left"
else:
self.truncation_side = "right"
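For concreteness, a trace of how this block resolves `openai/whisper-small` with both `runner` and `task` left at "auto" (values derived from the helpers added in this diff; a sketch, not the executed code path):

# Whisper is transcription-only, so the generate runner exposes only that task.
all_supported_tasks = {
    "generate": ["transcription"],
    "pooling": [],
    "draft": ["draft"],
}
supported_runner_types = {r for r, ts in all_supported_tasks.items() if ts}
# -> {"generate", "draft"}
# _resolve_runner: "WhisperForConditionalGeneration" ends with
# "ForConditionalGeneration" and "generate" is supported, so the runner is
# "generate"; self.task stays "auto", self.supported_tasks becomes
# ["transcription"], and truncation_side is set to "left".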
@@ -780,11 +809,10 @@ class ModelConfig:
f"one of {get_args(TokenizerMode)}.")
self.tokenizer_mode = tokenizer_mode
def _get_preferred_task(
def _get_preferred_pooling_task(
self,
architectures: list[str],
supported_tasks: set[_ResolvedTask],
) -> Optional[_ResolvedTask]:
) -> _ResolvedTask:
model_id = self.model
if get_pooling_config(model_id, self.revision):
return "embed"
@@ -795,92 +823,136 @@ class ModelConfig:
suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [
# Other models follow this pattern
("ForCausalLM", "generate"),
("ForConditionalGeneration", "generate"),
("ForSequenceClassification", "classify"),
("ChatModel", "generate"),
("LMHeadModel", "generate"),
("EmbeddingModel", "embed"),
("RewardModel", "reward"),
]
_, arch = self.registry.inspect_model_cls(architectures)
for suffix, pref_task in suffix_to_preferred_task:
if arch.endswith(suffix) and pref_task in supported_tasks:
if arch.endswith(suffix):
return pref_task
return None
return "embed"
def _resolve_task(
def _get_supported_generation_tasks(
self,
task_option: Literal[TaskOption, Literal["draft"]],
) -> tuple[set[_ResolvedTask], _ResolvedTask]:
if task_option == "draft":
return {"draft"}, "draft"
task_option: TaskOption,
) -> list[_ResolvedTask]:
registry = self.registry
architectures = self.architectures
runner_support: dict[RunnerType, bool] = {
# NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them
"transcription": registry.is_transcription_model(architectures),
"generate": registry.is_text_generation_model(architectures),
"pooling": registry.is_pooling_model(architectures),
if registry.is_transcription_only_model(architectures):
return ["transcription"]
supported_tasks = list[_ResolvedTask]()
if registry.is_text_generation_model(architectures):
supported_tasks.append("generate")
if registry.is_transcription_model(architectures):
supported_tasks.append("transcription")
return supported_tasks
def _get_supported_pooling_tasks(
self,
task_option: TaskOption,
) -> list[_ResolvedTask]:
registry = self.registry
architectures = self.architectures
supported_tasks = list[_ResolvedTask]()
if registry.is_pooling_model(architectures):
supported_tasks.append("pooling")
# For now, users must specify the task (other than "pooling")
# to use for pooling models
if task_option == "auto":
preferred_task = self._get_preferred_pooling_task(
architectures)
supported_tasks.append(preferred_task)
elif task_option in _RUNNER_TASKS["pooling"]:
supported_tasks.append(cast(_ResolvedTask, task_option))
return supported_tasks
def _get_supported_tasks(
self,
task_option: TaskOption,
) -> dict[RunnerType, list[_ResolvedTask]]:
return {
"generate": self._get_supported_generation_tasks(task_option),
"pooling": self._get_supported_pooling_tasks(task_option),
"draft": ["draft"]
}
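Two illustrative return values, inferred from the two helpers above (model names taken from this PR's tests):

# Transcription-only model, e.g. openai/whisper-small with task="auto":
whisper_tasks = {"generate": ["transcription"], "pooling": [], "draft": ["draft"]}

# Cross-encoder passed task="classify",
# e.g. cross-encoder/ms-marco-MiniLM-L-6-v2:
cross_encoder_tasks = {"generate": [], "pooling": ["pooling", "classify"],
                       "draft": ["draft"]}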
supported_runner_types_lst: list[RunnerType] = [
runner_type
for runner_type, is_supported in runner_support.items()
if is_supported
]
supported_tasks_lst: list[_ResolvedTask] = [
task for runner_type in supported_runner_types_lst
for task in _RUNNER_TASKS[runner_type]
]
supported_tasks = set(supported_tasks_lst)
def _get_supported_runner_types(
self,
supported_tasks: dict[RunnerType, list[_ResolvedTask]],
) -> set[RunnerType]:
return {
runner
for runner, runner_tasks in supported_tasks.items()
if len(runner_tasks) > 0
}
if task_option == "auto":
selected_task = next(iter(supported_tasks_lst))
def _resolve_runner(
self,
runner_option: RunnerOption,
task_option: TaskOption,
supported_runner_types: set[RunnerType],
supported_tasks: dict[RunnerType, list[_ResolvedTask]],
) -> RunnerType:
if not supported_runner_types:
raise ValueError("This model does not support any model runners!")
if len(supported_tasks_lst) > 1:
preferred_task = self._get_preferred_task(
architectures, supported_tasks)
if preferred_task is not None:
selected_task = preferred_task
if runner_option != "auto":
if runner_option not in supported_runner_types:
raise ValueError(
f"This model does not support runner={runner_option!r}. "
f"Available runners: {supported_runner_types}")
logger.info(
"This model supports multiple tasks: %s. "
"Defaulting to '%s'.", supported_tasks, selected_task)
else:
if task_option == "score":
if not runner_support["pooling"]:
msg = (f"This model does not support the '{task_option}' "
f"task. Supported tasks: {supported_tasks}")
raise ValueError(msg)
if self.registry.is_cross_encoder_model(architectures):
task_option = "classify"
else:
task_option = "embed"
return runner_option
if task_option != "auto":
for runner, runner_tasks in supported_tasks.items():
if task_option in runner_tasks:
return runner
else:
# Aliases
if task_option == "embedding":
msg = ("The 'embedding' task has been renamed to "
"'embed', please use the new name. The old name "
"will be removed in v1.0.")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
task_runner: RunnerType = next(
runner for runner, tasks in _RUNNER_TASKS.items()
if task_option in tasks)
raise ValueError(
f"This model does not support task={task_option!r}. "
f"Available tasks for runner={task_runner!r}: "
f"{supported_tasks[task_runner]}")
task_option = "embed"
suffix_to_preferred_runner: list[tuple[str, RunnerType]] = [
("ForCausalLM", "generate"),
("ForConditionalGeneration", "generate"),
("ChatModel", "generate"),
("LMHeadModel", "generate"),
("ForSequenceClassification", "pooling"),
("EmbeddingModel", "pooling"),
("RewardModel", "pooling"),
]
_, arch = self.registry.inspect_model_cls(self.architectures)
if task_option not in supported_tasks:
msg = (
f"This model does not support the '{task_option}' task. "
f"Supported tasks: {supported_tasks}")
raise ValueError(msg)
for suffix, pref_runner in suffix_to_preferred_runner:
if arch.endswith(suffix) and pref_runner in supported_runner_types:
return pref_runner
selected_task = task_option
if "classify" in supported_tasks.get("pooling", []):
# When multiple pooling tasks are present, default to
# pooling (e.g. cross-encoder) for non-standard architectures.
return "pooling"
if "generate" in supported_runner_types:
return "generate"
if "pooling" in supported_runner_types:
return "pooling"
return supported_tasks, selected_task
raise AssertionError("This line should not be reached")
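The resolution order implemented above is: explicit runner, then explicit task, then architecture suffix, then the classify/generate/pooling fallbacks. The failure path is what the updated test asserts; a hedged sketch mirroring it (constructor arguments as in the tests, network access assumed):

import pytest
from vllm.config import ModelConfig

# A reward model cannot be forced onto the generate runner; the error now
# names the rejected task and the tasks its runner would accept.
with pytest.raises(ValueError, match=r"does not support task=.*"):
    ModelConfig("Qwen/Qwen2.5-Math-RM-72B", task="generate",
                tokenizer="Qwen/Qwen2.5-Math-RM-72B",
                seed=0, dtype="float16")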
def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None)
@@ -1449,14 +1521,6 @@ class ModelConfig:
def use_mla(self) -> bool:
return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
@property
def supported_runner_types(self) -> set[RunnerType]:
return {_TASK_RUNNER[task] for task in self.supported_tasks}
@property
def runner_type(self) -> RunnerType:
return _TASK_RUNNER[cast(_ResolvedTask, self.task)]
@property
def is_v1_compatible(self) -> bool:
architectures = getattr(self.hf_config, "architectures", [])
@@ -2694,7 +2758,7 @@ class SpeculativeConfig:
if self.model is not None:
self.draft_model_config = ModelConfig(
model=self.model,
task="draft",
runner="draft",
tokenizer=self.target_model_config.tokenizer,
tokenizer_mode=self.target_model_config.tokenizer_mode,
trust_remote_code=self.target_model_config.

View File

@@ -454,20 +454,19 @@ class LLM:
considered legacy and may be deprecated in the future. You should
instead pass them via the `inputs` parameter.
"""
runner_type = self.llm_engine.model_config.runner_type
if runner_type not in ["generate", "transcription"]:
model_config = self.llm_engine.model_config
runner_type = model_config.runner_type
if runner_type != "generate":
messages = [
"LLM.generate() is only supported for (conditional) generation "
"models (XForCausalLM, XForConditionalGeneration).",
"LLM.generate() is only supported for generative models."
]
supported_runner_types = self.llm_engine.model_config \
.supported_runner_types
if "generate" in supported_runner_types:
if "generate" in model_config.supported_runner_types:
messages.append(
"Your model supports the 'generate' runner, but is "
f"currently initialized for the '{runner_type}' runner. "
"Please initialize vLLM using `--task generate`.")
"Please initialize vLLM using `--task generate` or "
"`--task transcription`.")
raise ValueError(" ".join(messages))
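Because transcription now sits under the generate runner, this check passes for both plain decoders and Whisper-style models; a minimal happy-path sketch with a model from this PR's tests (standard vLLM entrypoint usage, not new API):

from vllm import LLM, SamplingParams

# Any model whose runner resolves to "generate" gets past this check; the
# task selected within that runner no longer matters here.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)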
@@ -1091,13 +1090,12 @@ class LLM:
considered legacy and may be deprecated in the future. You should
instead pass them via the `inputs` parameter.
"""
runner_type = self.llm_engine.model_config.runner_type
model_config = self.llm_engine.model_config
runner_type = model_config.runner_type
if runner_type != "pooling":
messages = ["LLM.encode() is only supported for pooling models."]
supported_runner_types = self.llm_engine.model_config \
.supported_runner_types
if "pooling" in supported_runner_types:
if "pooling" in model_config.supported_runner_types:
messages.append(
"Your model supports the 'pooling' runner, but is "
f"currently initialized for the '{runner_type}' runner. "
@@ -1119,13 +1117,13 @@ class LLM:
# Use default pooling params.
pooling_params = PoolingParams()
elif isinstance(pooling_params, PoolingParams):
pooling_params.verify(self.llm_engine.model_config)
pooling_params.verify(model_config)
else:
for pooling_param in pooling_params:
pooling_param.verify(self.llm_engine.model_config)
pooling_param.verify(model_config)
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(self.llm_engine.model_config.max_model_len,
tokenization_kwargs = dict[str, Any]()
_validate_truncation_size(model_config.max_model_len,
truncate_prompt_tokens, tokenization_kwargs)
self._validate_and_add_requests(
@@ -1178,9 +1176,10 @@ class LLM:
A list of `EmbeddingRequestOutput` objects containing the
embedding vectors in the same order as the input prompts.
"""
if self.llm_engine.model_config.task != "embed":
raise ValueError(
"Embedding API is only enabled for `--task embed`")
model_config = self.llm_engine.model_config
if "embed" not in model_config.supported_tasks:
raise ValueError("Embedding API is not supported by this model. "
"Please set `--task embed`.")
items = self.encode(prompts,
truncate_prompt_tokens=truncate_prompt_tokens,
@@ -1223,9 +1222,11 @@ class LLM:
A list of `ClassificationRequestOutput` objects containing the
classification outputs in the same order as the input prompts.
"""
if self.llm_engine.model_config.task != "classify":
model_config = self.llm_engine.model_config
if "classify" not in model_config.supported_tasks:
raise ValueError(
"Classification API is only enabled for `--task classify`")
"Classification API is not supported by this model. "
"Please set `--task classify`.")
items = self.encode(prompts,
use_tqdm=use_tqdm,
@@ -1392,13 +1393,12 @@ class LLM:
A list of `ScoringRequestOutput` objects containing the
generated scores in the same order as the input prompts.
"""
runner_type = self.llm_engine.model_config.runner_type
model_config = self.llm_engine.model_config
runner_type = model_config.runner_type
if runner_type != "pooling":
messages = ["LLM.score() is only supported for pooling models."]
supported_runner_types = self.llm_engine.model_config \
.supported_runner_types
if "pooling" in supported_runner_types:
if "pooling" in model_config.supported_runner_types:
messages.append(
"Your model supports the 'pooling' runner, but is "
f"currently initialized for the '{runner_type}' runner. "
@@ -1407,12 +1407,13 @@ class LLM:
raise ValueError(" ".join(messages))
if self.llm_engine.model_config.task not in ("embed", "classify"):
raise ValueError("Score API is only enabled for "
"`--task embed or --task classify`.")
if all(t not in model_config.supported_tasks
for t in ("embed", "classify")):
raise ValueError("Score API is not supported by this model. "
"Please set `--task embed` or `--task classify`.")
if (self.llm_engine.model_config.task == "classify"
and self.llm_engine.model_config.hf_config.num_labels != 1):
if (model_config.task == "classify"
and getattr(model_config.hf_config, "num_labels", 0) != 1):
raise ValueError("Score API is only enabled for num_labels == 1.")
# the tokenizer for models such as
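A hedged usage sketch of the Score API with the cross-encoder used in this PR's tests (it resolves to the pooling runner with "classify" among its supported tasks and has num_labels == 1):

from vllm import LLM

llm = LLM(model="cross-encoder/ms-marco-MiniLM-L-6-v2", task="classify")
outputs = llm.score("What is the capital of France?",
                    ["Paris is the capital of France.",
                     "The Alps are a mountain range."])
# Each ScoringRequestOutput carries a single relevance score.
print([o.outputs.score for o in outputs])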

View File

@@ -1520,7 +1520,7 @@ async def init_app_state(
reasoning_parser=args.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
) if model_config.runner_type == "generate" else None
) if "generate" in model_config.supported_tasks else None
state.openai_serving_chat = OpenAIServingChat(
engine_client,
model_config,
@@ -1537,7 +1537,7 @@ async def init_app_state(
reasoning_parser=args.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
) if model_config.runner_type == "generate" else None
) if "generate" in model_config.supported_tasks else None
state.openai_serving_completion = OpenAIServingCompletion(
engine_client,
model_config,
@@ -1545,7 +1545,7 @@ async def init_app_state(
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_force_include_usage=args.enable_force_include_usage,
) if model_config.runner_type == "generate" else None
) if "generate" in model_config.supported_tasks else None
state.openai_serving_pooling = OpenAIServingPooling(
engine_client,
model_config,
@@ -1553,7 +1553,7 @@ async def init_app_state(
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if model_config.runner_type == "pooling" else None
) if "pooling" in model_config.supported_tasks else None
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
@@ -1561,22 +1561,24 @@ async def init_app_state(
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if model_config.task == "embed" else None
) if "embed" in model_config.supported_tasks else None
state.openai_serving_classification = ServingClassification(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger,
) if model_config.task == "classify" else None
) if "classify" in model_config.supported_tasks else None
enable_serving_reranking = (model_config.task == "classify" and getattr(
model_config.hf_config, "num_labels", 0) == 1)
enable_serving_reranking = ("classify" in model_config.supported_tasks
and getattr(model_config.hf_config,
"num_labels", 0) == 1)
state.openai_serving_scores = ServingScores(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger) if (
model_config.task == "embed" or enable_serving_reranking) else None
request_logger=request_logger,
) if ("embed" in model_config.supported_tasks
or enable_serving_reranking) else None
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
@@ -1591,13 +1593,13 @@ async def init_app_state(
model_config,
state.openai_serving_models,
request_logger=request_logger,
) if model_config.runner_type == "transcription" else None
) if "transcription" in model_config.supported_tasks else None
state.openai_serving_translation = OpenAIServingTranslation(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger,
) if model_config.runner_type == "transcription" else None
) if "transcription" in model_config.supported_tasks else None
state.task = model_config.task
state.enable_server_load_tracking = args.enable_server_load_tracking
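Switching from equality on a single resolved task to membership in supported_tasks means one model can enable several handlers at once. An illustrative evaluation for a cross-encoder (supported_tasks of ["pooling", "classify"], num_labels == 1); the handler names are the attributes assigned above:

supported_tasks = ["pooling", "classify"]
num_labels = 1
enable_serving_reranking = "classify" in supported_tasks and num_labels == 1
handlers = {
    "openai_serving_pooling": "pooling" in supported_tasks,          # True
    "openai_serving_classification": "classify" in supported_tasks,  # True
    "openai_serving_embedding": "embed" in supported_tasks,          # False
    "openai_serving_scores": ("embed" in supported_tasks
                              or enable_serving_reranking),          # True
    "openai_serving_completion": "generate" in supported_tasks,      # False
}
print(handlers)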

View File

@@ -348,7 +348,7 @@ async def main(args):
chat_template=None,
chat_template_content_format="auto",
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if model_config.runner_type == "generate" else None
) if "generate" in model_config.supported_tasks else None
openai_serving_embedding = OpenAIServingEmbedding(
engine,
model_config,
@@ -356,17 +356,19 @@ async def main(args):
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
) if model_config.task == "embed" else None
) if "embed" in model_config.supported_tasks else None
enable_serving_reranking = (model_config.task == "classify" and getattr(
model_config.hf_config, "num_labels", 0) == 1)
enable_serving_reranking = ("classify" in model_config.supported_tasks
and getattr(model_config.hf_config,
"num_labels", 0) == 1)
openai_serving_scores = (ServingScores(
openai_serving_scores = ServingScores(
engine,
model_config,
openai_serving_models,
request_logger=request_logger,
) if (model_config.task == "embed" or enable_serving_reranking) else None)
) if ("embed" in model_config.supported_tasks
or enable_serving_reranking) else None
tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)

View File

@@ -694,6 +694,12 @@ class SupportsTranscription(Protocol):
supports_transcription: ClassVar[Literal[True]] = True
supports_transcription_only: ClassVar[bool] = False
"""
Transcription models can opt out of text generation by setting this to
`True`.
"""
@classmethod
def get_generation_prompt(cls, audio: np.ndarray,
stt_config: SpeechToTextConfig, language: str,
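A sketch of how a model class would use the new flag; the class below is hypothetical, but Whisper does exactly this in the last file of this diff:

import torch.nn as nn

from vllm.model_executor.models.interfaces import SupportsTranscription

class MyAudioOnlyModel(nn.Module, SupportsTranscription):
    # Hypothetical example: advertise transcription support while opting out
    # of free-form text generation, so the config resolver only offers the
    # "transcription" task under the "generate" runner.
    supports_transcription_only = True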

View File

@@ -284,6 +284,7 @@ class _ModelInfo:
is_hybrid: bool
has_noops: bool
supports_transcription: bool
supports_transcription_only: bool
supports_v0_only: bool
@staticmethod
@@ -299,6 +300,8 @@ class _ModelInfo:
is_attention_free=is_attention_free(model),
is_hybrid=is_hybrid(model),
supports_transcription=supports_transcription(model),
supports_transcription_only=(supports_transcription(model) and
model.supports_transcription_only),
supports_v0_only=supports_v0_only(model),
has_noops=has_noops(model),
)
@@ -573,6 +576,13 @@ class _ModelRegistry:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_transcription
def is_transcription_only_model(
self,
architectures: Union[str, list[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_transcription_only
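Assuming the usual `ModelRegistry` export from the top-level package, the new check can be queried directly (architecture names as they appear in this diff):

from vllm import ModelRegistry

# Whisper opts into transcription-only (see the last file in this diff),
# while an ordinary causal LM does not.
print(ModelRegistry.is_transcription_only_model("WhisperForConditionalGeneration"))  # True
print(ModelRegistry.is_transcription_only_model("Qwen2ForCausalLM"))                 # False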
def is_v1_compatible(
self,
architectures: Union[str, list[str]],

View File

@@ -772,6 +772,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
".fc2.": ".mlp.fc2."
})
# Whisper only supports audio-conditioned generation.
supports_transcription_only = True
@classmethod
def validate_language(cls, language: str) -> bool:
if language in ISO639_1_SUPPORTED_LANGS: