From cca2d2cdbe56529205c10e58363c7bd2d31e15df Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Fri, 21 Nov 2025 07:01:54 -0500 Subject: [PATCH] [Core] Align whisper closer to other multimodal models (#27292) Signed-off-by: Russell Bryant --- vllm/model_executor/models/whisper.py | 13 ++++--- vllm/v1/worker/gpu_model_runner.py | 49 +++++++-------------------- 2 files changed, 21 insertions(+), 41 deletions(-) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 91a10b95a08c..50587c627160 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -599,15 +599,16 @@ class WhisperModel(nn.Module): def forward( self, - input_features: torch.Tensor | list[torch.Tensor] | None, input_ids: torch.Tensor | None, positions: torch.Tensor, + encoder_outputs: list[torch.Tensor], ) -> torch.Tensor: - encoder_outputs = self.get_encoder_outputs(input_features) + assert len(encoder_outputs) in (0, 1) + enc_states = encoder_outputs[0] if len(encoder_outputs) == 1 else None decoder_outputs = self.decoder( input_ids=input_ids, positions=positions, - encoder_hidden_states=encoder_outputs, + encoder_hidden_states=enc_states, ) return decoder_outputs @@ -894,13 +895,15 @@ class WhisperForConditionalGeneration( self, input_ids: torch.Tensor, positions: torch.Tensor, + encoder_outputs: list[torch.Tensor] | None = None, **kwargs, ) -> torch.Tensor: - audio_input = self._parse_and_validate_audio_input(**kwargs) + if encoder_outputs is None: + encoder_outputs = [] decoder_outputs = self.model( - input_features=audio_input["input_features"], input_ids=input_ids, positions=positions, + encoder_outputs=encoder_outputs, ) return decoder_outputs diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4c65a5e9b029..e786cd8bc7c9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1923,14 +1923,16 @@ class GPUModelRunner( return mm_kwargs, mm_hashes_pos - def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): + def _execute_mm_encoder( + self, scheduler_output: "SchedulerOutput" + ) -> list[torch.Tensor]: # Batch the multi-modal inputs using the helper method. mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler( scheduler_output ) if not mm_kwargs: - return + return [] # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, @@ -2007,6 +2009,8 @@ class GPUModelRunner( logger.debug("Finish execute for mm hash %s", mm_hash) self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash) + return encoder_outputs + def _gather_mm_embeddings( self, scheduler_output: "SchedulerOutput", @@ -2095,38 +2099,6 @@ class GPUModelRunner( return mm_embeds, is_mm_embed - def _extract_encoder_inputs( - self, - scheduler_output: "SchedulerOutput", - ) -> dict[str, torch.Tensor]: - """Extract encoder inputs for encoder-decoder models. - - This method extracts multimodal input features from scheduled encoder - inputs and formats them for the encoder-decoder model forward pass. - """ - # Batch the multi-modal inputs using the helper method. 
- mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output) - - if not mm_kwargs: - return {} - - # Group MM kwargs by modality and extract features - model = cast(SupportsMultiModal, self.model) - encoder_features = {} - for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, - multimodal_cpu_fields=model.multimodal_cpu_fields, - ): - # Add the grouped features to encoder_features dict - # This allows the model to receive them as kwargs (e.g., - # input_features=...) - encoder_features.update(mm_kwargs_group) - - return encoder_features - def get_model(self) -> nn.Module: # get raw model out of the cudagraph wrapper. if isinstance(self.model, (CUDAGraphWrapper, UBatchWrapper)): @@ -2416,8 +2388,13 @@ class GPUModelRunner( self.model_config.is_encoder_decoder and scheduler_output.scheduled_encoder_inputs ): - encoder_inputs = self._extract_encoder_inputs(scheduler_output) - model_kwargs.update(encoder_inputs) + # Run the encoder, just like we do with other multimodal inputs. + # For an encoder-decoder model, our processing here is a bit + # simpler, because the outputs are just passed to the decoder. + # We are not doing any prompt replacement. We also will only + # ever have a single encoder input. + encoder_outputs = self._execute_mm_encoder(scheduler_output) + model_kwargs.update({"encoder_outputs": encoder_outputs}) return ( input_ids,
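
---

For reference, the control flow this patch establishes can be illustrated with a minimal, self-contained sketch. This is not vLLM's actual code: the ToyEncoder/ToyEncoderDecoder names and their internals are hypothetical stand-ins, kept just close enough to show the contract the patched WhisperModel.forward() and GPUModelRunner now share — the runner executes the encoder as a separate step, and the model's forward() consumes precomputed encoder outputs instead of raw input_features.

# Minimal sketch (hypothetical Toy* names, not vLLM's actual code) of the
# control flow after this patch: the runner runs the encoder up front, and
# the decoder-side forward() only receives the precomputed outputs.
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    def __init__(self, feat_dim: int = 80, hidden: int = 16) -> None:
        super().__init__()
        self.proj = nn.Linear(feat_dim, hidden)

    def forward(self, input_features: torch.Tensor) -> torch.Tensor:
        # (num_frames, feat_dim) -> (num_frames, hidden)
        return self.proj(input_features)


class ToyEncoderDecoder(nn.Module):
    def __init__(self, vocab_size: int = 32, hidden: int = 16) -> None:
        super().__init__()
        self.encoder = ToyEncoder(hidden=hidden)
        self.embed = nn.Embedding(vocab_size, hidden)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,  # kept only for signature parity
        encoder_outputs: list[torch.Tensor] | None = None,
    ) -> torch.Tensor:
        if encoder_outputs is None:
            encoder_outputs = []
        # Same contract as the patched WhisperModel.forward(): at most one
        # encoder input per batch; an empty list means a pure decode step,
        # where the cross-attention states are already cached.
        assert len(encoder_outputs) in (0, 1)
        enc_states = encoder_outputs[0] if encoder_outputs else None
        hidden = self.embed(input_ids)
        if enc_states is not None:
            # Stand-in for cross-attention over the encoder states.
            hidden = hidden + enc_states.mean(dim=0)
        return hidden


model = ToyEncoderDecoder()

# Runner side: run the encoder first, the role _execute_mm_encoder now
# plays for encoder-decoder models...
encoder_outputs = [model.encoder(torch.randn(100, 80))]

# ...then hand the outputs to the model through kwargs, replacing the old
# path where forward() received raw input_features and invoked the
# encoder itself.
logits = model(
    input_ids=torch.tensor([1, 2, 3]),
    positions=torch.arange(3),
    encoder_outputs=encoder_outputs,
)

Routing the encoder outputs through model_kwargs is what lets the whisper-specific _extract_encoder_inputs helper be deleted: the generic _execute_mm_encoder path now covers the encoder-decoder case as well, and, as the comment in the patch notes, no prompt replacement is needed because the outputs are passed straight to the decoder.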