diff --git a/tests/test_config.py b/tests/test_config.py
index 6ed7ef9e6a40d..a160b08f28aa5 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -54,7 +54,7 @@ def test_get_field():
         ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
         ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
         ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
-        ("openai/whisper-small", "transcription", "transcription"),
+        ("openai/whisper-small", "generate", "transcription"),
     ],
 )
 def test_auto_task(model_id, expected_runner_type, expected_task):
@@ -69,7 +69,11 @@ def test_auto_task(model_id, expected_runner_type, expected_task):
     )
 
     assert config.runner_type == expected_runner_type
-    assert config.task == expected_task
+
+    if config.runner_type == "pooling":
+        assert config.task == expected_task
+    else:
+        assert expected_task in config.supported_tasks
 
 
 @pytest.mark.parametrize(
@@ -98,11 +102,50 @@ def test_score_task(model_id, expected_runner_type, expected_task):
     assert config.task == expected_task
 
 
+@pytest.mark.parametrize(("model_id", "expected_runner_type", "expected_task"),
+                         [
+                             ("Qwen/Qwen2.5-1.5B-Instruct", "draft", "auto"),
+                         ])
+def test_draft_task(model_id, expected_runner_type, expected_task):
+    config = ModelConfig(
+        model_id,
+        runner="draft",
+        tokenizer=model_id,
+        seed=0,
+        dtype="float16",
+    )
+
+    assert config.runner_type == expected_runner_type
+    assert config.task == expected_task
+
+
+@pytest.mark.parametrize(
+    ("model_id", "expected_runner_type", "expected_task"),
+    [
+        ("openai/whisper-small", "generate", "transcription"),
+    ],
+)
+def test_transcription_task(model_id, expected_runner_type, expected_task):
+    config = ModelConfig(
+        model_id,
+        task="transcription",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+    )
+
+    assert config.runner_type == expected_runner_type
+    assert config.task == expected_task
+
+
 @pytest.mark.parametrize(("model_id", "bad_task"), [
     ("Qwen/Qwen2.5-Math-RM-72B", "generate"),
+    ("Qwen/Qwen3-0.6B", "transcription"),
 ])
 def test_incorrect_task(model_id, bad_task):
-    with pytest.raises(ValueError, match=r"does not support the .* task"):
+    with pytest.raises(ValueError, match=r"does not support task=.*"):
         ModelConfig(
             model_id,
             task=bad_task,
diff --git a/vllm/config.py b/vllm/config.py
index cfd7b9e336704..ddaff0710a3b8 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -91,24 +91,19 @@
 logger = init_logger(__name__)
 
 ConfigT = TypeVar("ConfigT", bound=ConfigType)
 
 TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
-                     "score", "reward", "transcription"]
+                     "score", "reward", "transcription", "draft"]
 
-_ResolvedTask = Literal["generate", "embed", "classify", "reward", "draft",
-                        "transcription"]
+_ResolvedTask = Literal["generate", "transcription", "pooling", "embed",
+                        "classify", "reward", "draft"]
 
-RunnerType = Literal["generate", "pooling", "draft", "transcription"]
+RunnerOption = Literal["auto", "generate", "pooling", "draft"]
+
+RunnerType = Literal["generate", "pooling", "draft"]
 
 _RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
-    "generate": ["generate"],
-    "pooling": ["embed", "classify", "reward"],
-    "draft": ["draft"],
-    "transcription": ["transcription"],
-}
-
-_TASK_RUNNER: dict[_ResolvedTask, RunnerType] = {
-    task: runner
-    for runner, tasks in _RUNNER_TASKS.items()
-    for task in tasks
+    "generate": ["generate", "transcription"],
+    "pooling": ["pooling", "embed", "classify", "reward"],
+    "draft": [],
 }
@@ -234,11 +229,14 @@ class ModelConfig:
     """Name or path of the Hugging Face model to use. It is also used as the
     content for `model_name` tag in metrics output when `served_model_name` is
     not specified."""
-    task: Literal[TaskOption, Literal["draft"]] = "auto"
-    """The task to use the model for. Each vLLM instance only supports one
-    task, even if the same model can be used for multiple tasks. When the model
-    only supports one task, "auto" can be used to select it; otherwise, you
-    must specify explicitly which task to use."""
+    runner: RunnerOption = "auto"
+    """The type of model runner to use. Each vLLM instance only supports one
+    model runner, even if the same model can be used for multiple types."""
+    task: TaskOption = "auto"
+    """The task to use the model for. If the model supports more than one
+    model runner, this is used to select which model runner to run.
+
+    Note that the model may support other tasks using the same model runner."""
     tokenizer: SkipValidation[str] = None  # type: ignore
     """Name or path of the Hugging Face tokenizer to use. If unspecified,
     model name or path will be used."""
@@ -553,10 +551,41 @@ class ModelConfig:
         self.hf_image_processor_config = get_hf_image_processor_config(
             self.model, hf_token=self.hf_token, revision=self.revision)
 
-        supported_tasks, task = self._resolve_task(self.task)
-        self.supported_tasks = supported_tasks
-        self.task = task
-        if self.task in ("draft", "generate"):
+        # For pooling models, self.task is used to indicate the
+        # user-selected task
+        if self.task == "score":
+            if self.registry.is_cross_encoder_model(self.architectures):
+                self.task = "classify"
+            else:
+                self.task = "embed"
+        elif self.task == "embedding":
+            msg = ("The 'embedding' task has been renamed to 'embed', please "
+                   "use the new name. The old name will be removed in v1.0.")
+            warnings.warn(msg, DeprecationWarning, stacklevel=2)
+
+            self.task = "embed"
+
+        all_supported_tasks = self._get_supported_tasks(self.task)
+        logger.debug("Tasks supported by runner type: %s", all_supported_tasks)
+        supported_runner_types = self._get_supported_runner_types(
+            all_supported_tasks)
+        runner_type = self._resolve_runner(self.runner, self.task,
+                                           supported_runner_types,
+                                           all_supported_tasks)
+
+        logger.debug("Selected runner type: %s", runner_type)
+        # For pooling models, self.task is used to indicate the
+        # user-selected task
+        if runner_type == "pooling" and self.task == "auto":
+            selected_task = all_supported_tasks[runner_type][-1]
+            assert selected_task != "pooling"
+            self.task = selected_task
+        self.supported_runner_types = supported_runner_types
+        self.runner_type = runner_type
+        self.supported_tasks = all_supported_tasks[runner_type]
+
+        if self.runner_type in ("draft",
+                                "generate") and self.task != "transcription":
             self.truncation_side = "left"
         else:
             self.truncation_side = "right"
@@ -780,11 +809,10 @@ class ModelConfig:
                 f"one of {get_args(TokenizerMode)}.")
         self.tokenizer_mode = tokenizer_mode
 
-    def _get_preferred_task(
+    def _get_preferred_pooling_task(
         self,
         architectures: list[str],
-        supported_tasks: set[_ResolvedTask],
-    ) -> Optional[_ResolvedTask]:
+    ) -> _ResolvedTask:
         model_id = self.model
         if get_pooling_config(model_id, self.revision):
             return "embed"
@@ -795,92 +823,136 @@ class ModelConfig:
 
         suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [
             # Other models follow this pattern
-            ("ForCausalLM", "generate"),
-            ("ForConditionalGeneration", "generate"),
             ("ForSequenceClassification", "classify"),
-            ("ChatModel", "generate"),
-            ("LMHeadModel", "generate"),
             ("EmbeddingModel", "embed"),
             ("RewardModel", "reward"),
         ]
         _, arch = self.registry.inspect_model_cls(architectures)
 
         for suffix, pref_task in suffix_to_preferred_task:
-            if arch.endswith(suffix) and pref_task in supported_tasks:
+            if arch.endswith(suffix):
                 return pref_task
 
-        return None
+        return "embed"
 
-    def _resolve_task(
+    def _get_supported_generation_tasks(
         self,
-        task_option: Literal[TaskOption, Literal["draft"]],
-    ) -> tuple[set[_ResolvedTask], _ResolvedTask]:
-        if task_option == "draft":
-            return {"draft"}, "draft"
-
+        task_option: TaskOption,
+    ) -> list[_ResolvedTask]:
         registry = self.registry
         architectures = self.architectures
 
-        runner_support: dict[RunnerType, bool] = {
-            # NOTE: Listed from highest to lowest priority,
-            # in case the model supports multiple of them
-            "transcription": registry.is_transcription_model(architectures),
-            "generate": registry.is_text_generation_model(architectures),
-            "pooling": registry.is_pooling_model(architectures),
+        if registry.is_transcription_only_model(architectures):
+            return ["transcription"]
+
+        supported_tasks = list[_ResolvedTask]()
+        if registry.is_text_generation_model(architectures):
+            supported_tasks.append("generate")
+
+        if registry.is_transcription_model(architectures):
+            supported_tasks.append("transcription")
+
+        return supported_tasks
+
+    def _get_supported_pooling_tasks(
+        self,
+        task_option: TaskOption,
+    ) -> list[_ResolvedTask]:
+        registry = self.registry
+        architectures = self.architectures
+
+        supported_tasks = list[_ResolvedTask]()
+        if registry.is_pooling_model(architectures):
+            supported_tasks.append("pooling")
+
+            # For now, users must specify the task (other than "pooling")
+            # to use for pooling models
+            if task_option == "auto":
+                preferred_task = self._get_preferred_pooling_task(
+                    architectures)
+
+                supported_tasks.append(preferred_task)
+            elif task_option in _RUNNER_TASKS["pooling"]:
+                supported_tasks.append(cast(_ResolvedTask, task_option))
+
+        return supported_tasks
+
+    def _get_supported_tasks(
+        self,
+        task_option: TaskOption,
+    ) -> dict[RunnerType, list[_ResolvedTask]]:
+        return {
+            "generate": self._get_supported_generation_tasks(task_option),
+            "pooling": self._get_supported_pooling_tasks(task_option),
+            "draft": ["draft"]
         }
 
-        supported_runner_types_lst: list[RunnerType] = [
-            runner_type
-            for runner_type, is_supported in runner_support.items()
-            if is_supported
-        ]
-        supported_tasks_lst: list[_ResolvedTask] = [
-            task for runner_type in supported_runner_types_lst
-            for task in _RUNNER_TASKS[runner_type]
-        ]
-        supported_tasks = set(supported_tasks_lst)
+    def _get_supported_runner_types(
+        self,
+        supported_tasks: dict[RunnerType, list[_ResolvedTask]],
+    ) -> set[RunnerType]:
+        return {
+            runner
+            for runner, runner_tasks in supported_tasks.items()
+            if len(runner_tasks) > 0
+        }
 
-        if task_option == "auto":
-            selected_task = next(iter(supported_tasks_lst))
+    def _resolve_runner(
+        self,
+        runner_option: RunnerOption,
+        task_option: TaskOption,
+        supported_runner_types: set[RunnerType],
+        supported_tasks: dict[RunnerType, list[_ResolvedTask]],
+    ) -> RunnerType:
+        if not supported_runner_types:
+            raise ValueError("This model does not support any model runners!")
 
-            if len(supported_tasks_lst) > 1:
-                preferred_task = self._get_preferred_task(
-                    architectures, supported_tasks)
-                if preferred_task is not None:
-                    selected_task = preferred_task
+        if runner_option != "auto":
+            if runner_option not in supported_runner_types:
+                raise ValueError(
+                    f"This model does not support runner={runner_option!r}. "
+                    f"Available runners: {supported_runner_types}")
 
-                logger.info(
-                    "This model supports multiple tasks: %s. "
-                    "Defaulting to '%s'.", supported_tasks, selected_task)
-        else:
-            if task_option == "score":
-                if not runner_support["pooling"]:
-                    msg = (f"This model does not support the '{task_option}' "
-                           f"task. Supported tasks: {supported_tasks}")
-                    raise ValueError(msg)
-                if self.registry.is_cross_encoder_model(architectures):
-                    task_option = "classify"
-                else:
-                    task_option = "embed"
+            return runner_option
+
+        if task_option != "auto":
+            for runner, runner_tasks in supported_tasks.items():
+                if task_option in runner_tasks:
+                    return runner
             else:
-                # Aliases
-                if task_option == "embedding":
-                    msg = ("The 'embedding' task has been renamed to "
-                           "'embed', please use the new name. The old name "
-                           "will be removed in v1.0.")
-                    warnings.warn(msg, DeprecationWarning, stacklevel=2)
+                task_runner: RunnerType = next(
+                    runner for runner, tasks in _RUNNER_TASKS.items()
+                    if task_option in tasks)
+                raise ValueError(
+                    f"This model does not support task={task_option!r}. "
+                    f"Available tasks for runner={task_runner!r}: "
+                    f"{supported_tasks[task_runner]}")
 
-                task_option = "embed"
+        suffix_to_preferred_runner: list[tuple[str, RunnerType]] = [
+            ("ForCausalLM", "generate"),
+            ("ForConditionalGeneration", "generate"),
+            ("ChatModel", "generate"),
+            ("LMHeadModel", "generate"),
+            ("ForSequenceClassification", "pooling"),
+            ("EmbeddingModel", "pooling"),
+            ("RewardModel", "pooling"),
+        ]
+        _, arch = self.registry.inspect_model_cls(self.architectures)
 
-            if task_option not in supported_tasks:
-                msg = (
-                    f"This model does not support the '{task_option}' task. "
-                    f"Supported tasks: {supported_tasks}")
-                raise ValueError(msg)
+        for suffix, pref_runner in suffix_to_preferred_runner:
+            if arch.endswith(suffix) and pref_runner in supported_runner_types:
+                return pref_runner
 
-            selected_task = task_option
+        if "classify" in supported_tasks.get("pooling", []):
+            # When multiple pooling tasks are present, default to
+            # pooling (eg cross-encoder) for non-standard architectures.
+            return "pooling"
+        if "generate" in supported_runner_types:
+            return "generate"
+        if "pooling" in supported_runner_types:
+            return "pooling"
 
-        return supported_tasks, selected_task
+        raise AssertionError("This line should not be reached")
 
     def _parse_quant_hf_config(self):
         quant_cfg = getattr(self.hf_config, "quantization_config", None)
@@ -1449,14 +1521,6 @@ class ModelConfig:
     def use_mla(self) -> bool:
         return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
 
-    @property
-    def supported_runner_types(self) -> set[RunnerType]:
-        return {_TASK_RUNNER[task] for task in self.supported_tasks}
-
-    @property
-    def runner_type(self) -> RunnerType:
-        return _TASK_RUNNER[cast(_ResolvedTask, self.task)]
-
     @property
     def is_v1_compatible(self) -> bool:
         architectures = getattr(self.hf_config, "architectures", [])
@@ -2694,7 +2758,7 @@ class SpeculativeConfig:
         if self.model is not None:
             self.draft_model_config = ModelConfig(
                 model=self.model,
-                task="draft",
+                runner="draft",
                 tokenizer=self.target_model_config.tokenizer,
                 tokenizer_mode=self.target_model_config.tokenizer_mode,
                 trust_remote_code=self.target_model_config.
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index c60a566f585d9..e7398ecc23c8f 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -454,20 +454,19 @@ class LLM:
             considered legacy and may be deprecated in the future. You should
             instead pass them via the `inputs` parameter.
         """
-        runner_type = self.llm_engine.model_config.runner_type
-        if runner_type not in ["generate", "transcription"]:
+        model_config = self.llm_engine.model_config
+        runner_type = model_config.runner_type
+        if runner_type != "generate":
             messages = [
-                "LLM.generate() is only supported for (conditional) generation "
-                "models (XForCausalLM, XForConditionalGeneration).",
+                "LLM.generate() is only supported for generative models."
             ]
 
-            supported_runner_types = self.llm_engine.model_config \
-                .supported_runner_types
-            if "generate" in supported_runner_types:
+            if "generate" in model_config.supported_runner_types:
                 messages.append(
                     "Your model supports the 'generate' runner, but is "
                     f"currently initialized for the '{runner_type}' runner. "
-                    "Please initialize vLLM using `--task generate`.")
+                    "Please initialize vLLM using `--task generate` or "
+                    "`--task transcription`.")
 
             raise ValueError(" ".join(messages))
 
@@ -1091,13 +1090,12 @@ class LLM:
             considered legacy and may be deprecated in the future. You should
             instead pass them via the `inputs` parameter.
         """
-        runner_type = self.llm_engine.model_config.runner_type
+        model_config = self.llm_engine.model_config
+        runner_type = model_config.runner_type
         if runner_type != "pooling":
             messages = ["LLM.encode() is only supported for pooling models."]
 
-            supported_runner_types = self.llm_engine.model_config \
-                .supported_runner_types
-            if "pooling" in supported_runner_types:
+            if "pooling" in model_config.supported_runner_types:
                 messages.append(
                     "Your model supports the 'pooling' runner, but is "
                     f"currently initialized for the '{runner_type}' runner. "
@@ -1119,13 +1117,13 @@ class LLM:
             # Use default pooling params.
             pooling_params = PoolingParams()
         elif isinstance(pooling_params, PoolingParams):
-            pooling_params.verify(self.llm_engine.model_config)
+            pooling_params.verify(model_config)
         else:
             for pooling_param in pooling_params:
-                pooling_param.verify(self.llm_engine.model_config)
+                pooling_param.verify(model_config)
 
-        tokenization_kwargs: dict[str, Any] = {}
-        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
+        tokenization_kwargs = dict[str, Any]()
+        _validate_truncation_size(model_config.max_model_len,
                                   truncate_prompt_tokens, tokenization_kwargs)
 
         self._validate_and_add_requests(
@@ -1178,9 +1176,10 @@ class LLM:
             A list of `EmbeddingRequestOutput` objects containing the
             embedding vectors in the same order as the input prompts.
         """
-        if self.llm_engine.model_config.task != "embed":
-            raise ValueError(
-                "Embedding API is only enabled for `--task embed`")
+        model_config = self.llm_engine.model_config
+        if "embed" not in model_config.supported_tasks:
+            raise ValueError("Embedding API is not supported by this model. "
+                             "Please set `--task embed`.")
 
         items = self.encode(prompts,
                             truncate_prompt_tokens=truncate_prompt_tokens,
@@ -1223,9 +1222,11 @@ class LLM:
             A list of `ClassificationRequestOutput` objects containing the
             embedding vectors in the same order as the input prompts.
         """
-        if self.llm_engine.model_config.task != "classify":
+        model_config = self.llm_engine.model_config
+        if "classify" not in model_config.supported_tasks:
             raise ValueError(
-                "Classification API is only enabled for `--task classify`")
+                "Classification API is not supported by this model. "
+                "Please set `--task classify`.")
 
         items = self.encode(prompts,
                             use_tqdm=use_tqdm,
@@ -1392,13 +1393,12 @@ class LLM:
             A list of `ScoringRequestOutput` objects containing the generated
             scores in the same order as the input prompts.
         """
-        runner_type = self.llm_engine.model_config.runner_type
+        model_config = self.llm_engine.model_config
+        runner_type = model_config.runner_type
         if runner_type != "pooling":
             messages = ["LLM.score() is only supported for pooling models."]
 
-            supported_runner_types = self.llm_engine.model_config \
-                .supported_runner_types
-            if "pooling" in supported_runner_types:
+            if "pooling" in model_config.supported_runner_types:
                 messages.append(
                     "Your model supports the 'pooling' runner, but is "
                     f"currently initialized for the '{runner_type}' runner. "
@@ -1407,12 +1407,13 @@ class LLM:
 
             raise ValueError(" ".join(messages))
 
-        if self.llm_engine.model_config.task not in ("embed", "classify"):
-            raise ValueError("Score API is only enabled for "
-                             "`--task embed or --task classify`.")
+        if all(t not in model_config.supported_tasks
+               for t in ("embed", "classify")):
+            raise ValueError("Score API is not supported by this model. "
+                             "Please set `--task embed` or `--task classify`.")
 
-        if (self.llm_engine.model_config.task == "classify"
-                and self.llm_engine.model_config.hf_config.num_labels != 1):
+        if (model_config.task == "classify"
+                and getattr(model_config.hf_config, "num_labels", 0) != 1):
             raise ValueError("Score API is only enabled for num_labels == 1.")
 
         # the tokenizer for models such as
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 2f53357e1d4cf..049a90fea1561 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -1520,7 +1520,7 @@ async def init_app_state(
         reasoning_parser=args.reasoning_parser,
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
         enable_force_include_usage=args.enable_force_include_usage,
-    ) if model_config.runner_type == "generate" else None
+    ) if "generate" in model_config.supported_tasks else None
     state.openai_serving_chat = OpenAIServingChat(
         engine_client,
         model_config,
@@ -1537,7 +1537,7 @@ async def init_app_state(
         reasoning_parser=args.reasoning_parser,
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
         enable_force_include_usage=args.enable_force_include_usage,
-    ) if model_config.runner_type == "generate" else None
+    ) if "generate" in model_config.supported_tasks else None
     state.openai_serving_completion = OpenAIServingCompletion(
         engine_client,
         model_config,
@@ -1545,7 +1545,7 @@ async def init_app_state(
         request_logger=request_logger,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_force_include_usage=args.enable_force_include_usage,
-    ) if model_config.runner_type == "generate" else None
+    ) if "generate" in model_config.supported_tasks else None
     state.openai_serving_pooling = OpenAIServingPooling(
         engine_client,
         model_config,
@@ -1553,7 +1553,7 @@ async def init_app_state(
         request_logger=request_logger,
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
-    ) if model_config.runner_type == "pooling" else None
+    ) if "pooling" in model_config.supported_tasks else None
     state.openai_serving_embedding = OpenAIServingEmbedding(
         engine_client,
         model_config,
@@ -1561,22 +1561,24 @@ async def init_app_state(
         request_logger=request_logger,
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
-    ) if model_config.task == "embed" else None
+    ) if "embed" in model_config.supported_tasks else None
     state.openai_serving_classification = ServingClassification(
         engine_client,
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
-    ) if model_config.task == "classify" else None
+    ) if "classify" in model_config.supported_tasks else None
 
-    enable_serving_reranking = (model_config.task == "classify" and getattr(
-        model_config.hf_config, "num_labels", 0) == 1)
+    enable_serving_reranking = ("classify" in model_config.supported_tasks
+                                and getattr(model_config.hf_config,
+                                            "num_labels", 0) == 1)
     state.openai_serving_scores = ServingScores(
         engine_client,
         model_config,
         state.openai_serving_models,
-        request_logger=request_logger) if (
-            model_config.task == "embed" or enable_serving_reranking) else None
+        request_logger=request_logger,
+    ) if ("embed" in model_config.supported_tasks
+          or enable_serving_reranking) else None
 
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
@@ -1591,13 +1593,13 @@ async def init_app_state(
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
-    ) if model_config.runner_type == "transcription" else None
+    ) if "transcription" in model_config.supported_tasks else None
     state.openai_serving_translation = OpenAIServingTranslation(
         engine_client,
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
-    ) if model_config.runner_type == "transcription" else None
+    ) if "transcription" in model_config.supported_tasks else None
 
     state.task = model_config.task
     state.enable_server_load_tracking = args.enable_server_load_tracking
diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py
index e112e2f893a09..3dc5826909a02 100644
--- a/vllm/entrypoints/openai/run_batch.py
+++ b/vllm/entrypoints/openai/run_batch.py
@@ -348,7 +348,7 @@ async def main(args):
         chat_template=None,
         chat_template_content_format="auto",
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
-    ) if model_config.runner_type == "generate" else None
+    ) if "generate" in model_config.supported_tasks else None
     openai_serving_embedding = OpenAIServingEmbedding(
         engine,
         model_config,
@@ -356,17 +356,19 @@ async def main(args):
         request_logger=request_logger,
         chat_template=None,
         chat_template_content_format="auto",
-    ) if model_config.task == "embed" else None
+    ) if "embed" in model_config.supported_tasks else None
 
-    enable_serving_reranking = (model_config.task == "classify" and getattr(
-        model_config.hf_config, "num_labels", 0) == 1)
+    enable_serving_reranking = ("classify" in model_config.supported_tasks
+                                and getattr(model_config.hf_config,
+                                            "num_labels", 0) == 1)
 
-    openai_serving_scores = (ServingScores(
+    openai_serving_scores = ServingScores(
         engine,
         model_config,
         openai_serving_models,
         request_logger=request_logger,
-    ) if (model_config.task == "embed" or enable_serving_reranking) else None)
+    ) if ("embed" in model_config.supported_tasks
+          or enable_serving_reranking) else None
 
     tracker = BatchProgressTracker()
     logger.info("Reading batch from %s...", args.input_file)
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index 99669a233634b..3a97641aa2f2e 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -694,6 +694,12 @@ class SupportsTranscription(Protocol):
 
     supports_transcription: ClassVar[Literal[True]] = True
 
+    supports_transcription_only: ClassVar[bool] = False
+    """
+    Transcription models can opt out of text generation by setting this to
+    `True`.
+    """
+
     @classmethod
     def get_generation_prompt(cls, audio: np.ndarray,
                               stt_config: SpeechToTextConfig, language: str,
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 5f9b145b6615e..e8530a555d286 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -284,6 +284,7 @@ class _ModelInfo:
     is_hybrid: bool
     has_noops: bool
     supports_transcription: bool
+    supports_transcription_only: bool
    supports_v0_only: bool
 
     @staticmethod
@@ -299,6 +300,8 @@ class _ModelInfo:
             is_attention_free=is_attention_free(model),
             is_hybrid=is_hybrid(model),
             supports_transcription=supports_transcription(model),
+            supports_transcription_only=(supports_transcription(model) and
+                                         model.supports_transcription_only),
             supports_v0_only=supports_v0_only(model),
             has_noops=has_noops(model),
         )
@@ -573,6 +576,13 @@ class _ModelRegistry:
         model_cls, _ = self.inspect_model_cls(architectures)
         return model_cls.supports_transcription
 
+    def is_transcription_only_model(
+        self,
+        architectures: Union[str, list[str]],
+    ) -> bool:
+        model_cls, _ = self.inspect_model_cls(architectures)
+        return model_cls.supports_transcription_only
+
     def is_v1_compatible(
         self,
         architectures: Union[str, list[str]],
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index 1a7982e48e4b1..08aed2205e0a5 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -772,6 +772,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
         ".fc2.": ".mlp.fc2."
     })
 
+    # Whisper only supports audio-conditioned generation.
+    supports_transcription_only = True
+
     @classmethod
     def validate_language(cls, language: str) -> bool:
         if language in ISO639_1_SUPPORTED_LANGS: