From a2dd48c386f787cdfbe35099e0d273f19748dbb6 Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Fri, 28 Feb 2025 03:14:55 +0800
Subject: [PATCH] [VLM] Deprecate legacy input mapper for OOT multimodal models (#13979)

Signed-off-by: DarkLight1337
---
 vllm/config.py            | 45 ++++++++++++++++++++-------------------
 vllm/inputs/preprocess.py | 14 +++++++-----
 2 files changed, 32 insertions(+), 27 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index d1384c6375f30..cb683d19386b9 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -400,7 +400,7 @@ class ModelConfig:
         else:
             self.override_neuron_config = None
 
-        supported_tasks, task = self._resolve_task(task, self.hf_config)
+        supported_tasks, task = self._resolve_task(task)
         self.supported_tasks = supported_tasks
         self.task: Final = task
         if self.task in ("draft", "generate"):
@@ -418,6 +418,14 @@ class ModelConfig:
         self._verify_cuda_graph()
         self._verify_bnb_config()
 
+    @property
+    def registry(self):
+        return ModelRegistry
+
+    @property
+    def architectures(self) -> list[str]:
+        return getattr(self.hf_config, "architectures", [])
+
     def maybe_pull_model_tokenizer_for_s3(self, model: str,
                                           tokenizer: str) -> None:
         """
@@ -446,8 +454,7 @@ class ModelConfig:
     def _init_multimodal_config(
         self, limit_mm_per_prompt: Optional[Mapping[str, int]]
     ) -> Optional["MultiModalConfig"]:
-        architectures = getattr(self.hf_config, "architectures", [])
-        if ModelRegistry.is_multimodal_model(architectures):
+        if self.registry.is_multimodal_model(self.architectures):
            return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
 
        if limit_mm_per_prompt:
@@ -480,16 +487,13 @@ class ModelConfig:
         return None
 
     def _init_attention_free(self) -> bool:
-        architectures = getattr(self.hf_config, "architectures", [])
-        return ModelRegistry.is_attention_free_model(architectures)
+        return self.registry.is_attention_free_model(self.architectures)
 
     def _init_is_hybrid(self) -> bool:
-        architectures = getattr(self.hf_config, "architectures", [])
-        return ModelRegistry.is_hybrid_model(architectures)
+        return self.registry.is_hybrid_model(self.architectures)
 
     def _init_has_inner_state(self) -> bool:
-        architectures = getattr(self.hf_config, "architectures", [])
-        return ModelRegistry.model_has_inner_state(architectures)
+        return self.registry.model_has_inner_state(self.architectures)
 
     def _verify_tokenizer_mode(self) -> None:
         tokenizer_mode = self.tokenizer_mode.lower()
@@ -507,9 +511,9 @@ class ModelConfig:
         model_id = self.model
         if get_pooling_config(model_id, self.revision):
             return "embed"
-        if ModelRegistry.is_cross_encoder_model(architectures):
+        if self.registry.is_cross_encoder_model(architectures):
             return "score"
-        if ModelRegistry.is_transcription_model(architectures):
+        if self.registry.is_transcription_model(architectures):
             return "transcription"
 
         suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [
@@ -522,7 +526,7 @@ class ModelConfig:
             ("EmbeddingModel", "embed"),
             ("RewardModel", "reward"),
         ]
-        _, arch = ModelRegistry.inspect_model_cls(architectures)
+        _, arch = self.registry.inspect_model_cls(architectures)
 
         for suffix, pref_task in suffix_to_preferred_task:
             if arch.endswith(suffix) and pref_task in supported_tasks:
@@ -533,20 +537,19 @@ class ModelConfig:
     def _resolve_task(
         self,
         task_option: Union[TaskOption, Literal["draft"]],
-        hf_config: PretrainedConfig,
     ) -> Tuple[Set[_ResolvedTask], _ResolvedTask]:
         if task_option == "draft":
             return {"draft"}, "draft"
 
-        architectures = getattr(hf_config, "architectures", [])
+        registry = self.registry
+        architectures = self.architectures
 
         runner_support: Dict[RunnerType, bool] = {
             # NOTE: Listed from highest to lowest priority,
             # in case the model supports multiple of them
-            "transcription":
-            ModelRegistry.is_transcription_model(architectures),
-            "generate": ModelRegistry.is_text_generation_model(architectures),
-            "pooling": ModelRegistry.is_pooling_model(architectures),
+            "transcription": registry.is_transcription_model(architectures),
+            "generate": registry.is_text_generation_model(architectures),
+            "pooling": registry.is_pooling_model(architectures),
         }
         supported_runner_types_lst: List[RunnerType] = [
             runner_type
@@ -755,8 +758,7 @@ class ModelConfig:
 
         pipeline_parallel_size = parallel_config.pipeline_parallel_size
         if pipeline_parallel_size > 1:
-            architectures = getattr(self.hf_config, "architectures", [])
-            if not ModelRegistry.is_pp_supported_model(architectures):
+            if not self.registry.is_pp_supported_model(self.architectures):
                 raise NotImplementedError(
                     "Pipeline parallelism is not supported for this model. "
                     "Supported models implement the `SupportsPP` interface.")
@@ -1023,8 +1025,7 @@ class ModelConfig:
 
     @property
     def is_cross_encoder(self) -> bool:
-        architectures = getattr(self.hf_config, "architectures", [])
-        return ModelRegistry.is_cross_encoder_model(architectures)
+        return self.registry.is_cross_encoder_model(self.architectures)
 
     @property
     def use_mla(self) -> bool:
diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py
index bc5856990da6f..206a76e52b7ab 100644
--- a/vllm/inputs/preprocess.py
+++ b/vllm/inputs/preprocess.py
@@ -236,11 +236,15 @@ class InputPreprocessor:
         # updated to use the new multi-modal processor
         can_process_multimodal = self.mm_registry.has_processor(model_config)
         if not can_process_multimodal:
-            logger.info_once(
-                "Your model uses the legacy input pipeline instead of the new "
-                "multi-modal processor. Please note that the legacy pipeline "
-                "will be removed in a future release. For more details, see: "
-                "https://github.com/vllm-project/vllm/issues/10114")
+            from vllm.model_executor.models.registry import _VLLM_MODELS
+            if not any(arch in _VLLM_MODELS
+                       for arch in model_config.architectures):
+                logger.warning_once(
+                    "Your model uses the legacy input pipeline, which will be "
+                    "removed in an upcoming release. "
+                    "Please upgrade to the new multi-modal processing pipeline "
+                    "(https://docs.vllm.ai/en/latest/design/mm_processing.html)"
+                )
 
         return can_process_multimodal
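
Note: the sketch below is not part of the patch. It restates the gating condition added in preprocess.py to clarify who now sees the deprecation warning. The helper name would_warn_legacy_pipeline and its arguments are hypothetical; only the _VLLM_MODELS membership check and the can_process_multimodal flag come from the diff itself.

    from vllm.model_executor.models.registry import _VLLM_MODELS

    def would_warn_legacy_pipeline(architectures: list[str],
                                   can_process_multimodal: bool) -> bool:
        # Mirrors the new check in InputPreprocessor: the legacy-pipeline
        # deprecation warning is emitted only when no multi-modal processor
        # is registered for the model AND none of its architectures are
        # built into vLLM, i.e. only out-of-tree (OOT) multimodal models
        # are warned.
        if can_process_multimodal:
            return False
        return not any(arch in _VLLM_MODELS for arch in architectures)

In other words, built-in vLLM models that still use the legacy input pipeline no longer log the message; the warning is scoped to OOT multimodal models, matching the PR title.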