[Core] Align whisper closer to other multimodal models (#27292)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Russell Bryant 2025-11-21 07:01:54 -05:00 committed by GitHub
parent aab0102a26
commit cca2d2cdbe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 21 additions and 41 deletions

View File

@ -599,15 +599,16 @@ class WhisperModel(nn.Module):
def forward(
self,
input_features: torch.Tensor | list[torch.Tensor] | None,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
encoder_outputs: list[torch.Tensor],
) -> torch.Tensor:
encoder_outputs = self.get_encoder_outputs(input_features)
assert len(encoder_outputs) in (0, 1)
enc_states = encoder_outputs[0] if len(encoder_outputs) == 1 else None
decoder_outputs = self.decoder(
input_ids=input_ids,
positions=positions,
encoder_hidden_states=encoder_outputs,
encoder_hidden_states=enc_states,
)
return decoder_outputs
@ -894,13 +895,15 @@ class WhisperForConditionalGeneration(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
encoder_outputs: list[torch.Tensor] | None = None,
**kwargs,
) -> torch.Tensor:
audio_input = self._parse_and_validate_audio_input(**kwargs)
if encoder_outputs is None:
encoder_outputs = []
decoder_outputs = self.model(
input_features=audio_input["input_features"],
input_ids=input_ids,
positions=positions,
encoder_outputs=encoder_outputs,
)
return decoder_outputs

View File

@ -1923,14 +1923,16 @@ class GPUModelRunner(
return mm_kwargs, mm_hashes_pos
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
def _execute_mm_encoder(
self, scheduler_output: "SchedulerOutput"
) -> list[torch.Tensor]:
# Batch the multi-modal inputs using the helper method.
mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler(
scheduler_output
)
if not mm_kwargs:
return
return []
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
@ -2007,6 +2009,8 @@ class GPUModelRunner(
logger.debug("Finish execute for mm hash %s", mm_hash)
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
return encoder_outputs
def _gather_mm_embeddings(
self,
scheduler_output: "SchedulerOutput",
@ -2095,38 +2099,6 @@ class GPUModelRunner(
return mm_embeds, is_mm_embed
def _extract_encoder_inputs(
self,
scheduler_output: "SchedulerOutput",
) -> dict[str, torch.Tensor]:
"""Extract encoder inputs for encoder-decoder models.
This method extracts multimodal input features from scheduled encoder
inputs and formats them for the encoder-decoder model forward pass.
"""
# Batch the multi-modal inputs using the helper method.
mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output)
if not mm_kwargs:
return {}
# Group MM kwargs by modality and extract features
model = cast(SupportsMultiModal, self.model)
encoder_features = {}
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
pin_memory=self.pin_memory,
merge_by_field_config=model.merge_by_field_config,
multimodal_cpu_fields=model.multimodal_cpu_fields,
):
# Add the grouped features to encoder_features dict
# This allows the model to receive them as kwargs (e.g.,
# input_features=...)
encoder_features.update(mm_kwargs_group)
return encoder_features
def get_model(self) -> nn.Module:
# get raw model out of the cudagraph wrapper.
if isinstance(self.model, (CUDAGraphWrapper, UBatchWrapper)):
@ -2416,8 +2388,13 @@ class GPUModelRunner(
self.model_config.is_encoder_decoder
and scheduler_output.scheduled_encoder_inputs
):
encoder_inputs = self._extract_encoder_inputs(scheduler_output)
model_kwargs.update(encoder_inputs)
# Run the encoder, just like we do with other multimodal inputs.
# For an encoder-decoder model, our processing here is a bit
# simpler, because the outputs are just passed to the decoder.
# We are not doing any prompt replacement. We also will only
# ever have a single encoder input.
encoder_outputs = self._execute_mm_encoder(scheduler_output)
model_kwargs.update({"encoder_outputs": encoder_outputs})
return (
input_ids,