diff --git a/vllm/attention/layers/mm_encoder_attention.py b/vllm/attention/layers/mm_encoder_attention.py index 1c1623b13f55a..e1011cc361f55 100644 --- a/vllm/attention/layers/mm_encoder_attention.py +++ b/vllm/attention/layers/mm_encoder_attention.py @@ -134,6 +134,7 @@ class MMEncoderAttention(CustomOp): k=key, v=value, cu_seqlens=cu_seqlens, + softmax_scale=self.scale, ) if is_reshaped: output = output.reshape(bsz, q_len, -1) @@ -172,6 +173,7 @@ class MMEncoderAttention(CustomOp): batch_size=bsz, is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA), fa_version=self._fa_version, + softmax_scale=self.scale, ) if is_reshaped: output = output.reshape(bsz, q_len, -1) diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py index 2204382a35e2a..dc5d5e8c4904a 100644 --- a/vllm/attention/ops/vit_attn_wrappers.py +++ b/vllm/attention/ops/vit_attn_wrappers.py @@ -29,6 +29,7 @@ def flash_attn_maxseqlen_wrapper( fa_version: int | None, cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, + softmax_scale: float | None = None, ) -> torch.Tensor: kwargs = {} if is_rocm_aiter: @@ -57,6 +58,7 @@ def flash_attn_maxseqlen_wrapper( max_seqlen_k=max_seqlen, dropout_p=0.0, causal=False, + softmax_scale=softmax_scale, **kwargs, ) context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size) @@ -67,11 +69,12 @@ def flash_attn_maxseqlen_wrapper_fake( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - cu_seqlens: torch.Tensor, - max_seqlen: torch.Tensor, batch_size: int, is_rocm_aiter: bool, - fa_version: int | None, + fa_version: int, + cu_seqlens: torch.Tensor, + max_seqlen: torch.Tensor, + softmax_scale: float, ) -> torch.Tensor: return torch.empty_like(q) @@ -92,6 +95,7 @@ def vit_flash_attn_wrapper( fa_version: int | None, cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, + softmax_scale: float | None = None, ) -> torch.Tensor: return 
torch.ops.vllm.flash_attn_maxseqlen_wrapper( q, @@ -102,16 +106,17 @@ def vit_flash_attn_wrapper( fa_version, cu_seqlens, max_seqlen, + softmax_scale, ) -def apply_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: +def apply_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, softmax_scale: float | None = None) -> torch.Tensor: """ Input shape: (batch_size x seq_len x num_heads x head_size) """ q, k, v = (einops.rearrange(x, "b s h d -> b h s d") for x in [q, k, v]) - output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0) + output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, scale=softmax_scale) output = einops.rearrange(output, "b h s d -> b s h d ") return output @@ -123,6 +128,7 @@ def torch_sdpa_wrapper( k: torch.Tensor, v: torch.Tensor, cu_seqlens: torch.Tensor | None = None, + softmax_scale: float | None = None, ) -> torch.Tensor: # Never remove the contiguous logic for ROCm # Without it, hallucinations occur with the backend @@ -141,7 +147,7 @@ def torch_sdpa_wrapper( k_chunks = torch.split(k, lens, dim=1) v_chunks = torch.split(v, lens, dim=1) for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks): - output_i = apply_sdpa(q_i, k_i, v_i) + output_i = apply_sdpa(q_i, k_i, v_i, softmax_scale) outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) return context_layer @@ -152,6 +158,7 @@ def torch_sdpa_wrapper_fake( k: torch.Tensor, v: torch.Tensor, cu_seqlens: torch.Tensor, + softmax_scale: float, ) -> torch.Tensor: return torch.empty_like(q) @@ -168,5 +175,6 @@ def vit_torch_sdpa_wrapper( k: torch.Tensor, v: torch.Tensor, cu_seqlens: torch.Tensor | None = None, + softmax_scale: float | None = None, ) -> torch.Tensor: - return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens) + return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens, softmax_scale) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index cd527e4198557..299285c5525de 100644 --- a/vllm/config/compilation.py 
+++ b/vllm/config/compilation.py @@ -26,7 +26,7 @@ from vllm.utils.math_utils import round_up from vllm.utils.torch_utils import is_torch_equal_or_newer if TYPE_CHECKING: - from vllm.config import VllmConfig + from vllm.config import ModelConfig, VllmConfig else: VllmConfig = object @@ -829,6 +829,12 @@ class CompilationConfig: if self.backend == "": self.backend = current_platform.get_compile_backend() + def verify_with_model_config(self, model_config: "ModelConfig") -> None: + if model_config.is_encoder_decoder: + # Does not yet support encoder-decoder models. + self.inductor_compile_config["combo_kernels"] = False + self.inductor_compile_config["benchmark_combo_kernel"] = False + def init_backend(self, vllm_config: "VllmConfig") -> str | Callable: """ Initialize the backend for the compilation config from a vllm config. diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 0439dc52e7e6f..865c57689ba67 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -521,6 +521,7 @@ class VllmConfig: if self.model_config is not None: self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_dual_chunk_attention_config(self.load_config) + self.compilation_config.verify_with_model_config(self.model_config) self.cache_config.verify_with_parallel_config(self.parallel_config) @@ -753,14 +754,19 @@ class VllmConfig: if ( self.model_config and self.model_config.architecture == "WhisperForConditionalGeneration" - and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn" ): - logger.warning( - "Whisper is known to have issues with " - "forked workers. If startup is hanging, " - "try setting 'VLLM_WORKER_MULTIPROC_METHOD' " - "to 'spawn'." - ) + if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn": + logger.warning( + "Whisper is known to have issues with " + "forked workers. If startup is hanging, " + "try setting 'VLLM_WORKER_MULTIPROC_METHOD' " + "to 'spawn'." 
+ ) + if self.optimization_level > OptimizationLevel.O0: + self.compilation_config.compile_mm_encoder = True + logger.info( + "Enabling encoder compilation for Whisper for better performance." + ) if ( self.kv_events_config is not None diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index cbba1af89190c..8bc5bd3107dd8 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -816,7 +816,7 @@ class VoxtralEncoderModel(nn.Module): input_embeds, chunks_per_example = self.prepare_inputs_for_conv(input_features) # [total_num_chunks, ceil(chunk_size / downsample_factor), hidden_size] - out = self.whisper_encoder([input_embeds]) + out = self.whisper_encoder(input_embeds) # Re-concatenate the chunks chunk_idx = 0 diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index f1bae28debad2..be4b123d60f7f 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -25,6 +25,7 @@ from vllm.attention.layer import Attention from vllm.attention.layers.cross_attention import CrossAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig +from vllm.compilation.decorators import support_torch_compile from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.inputs.data import PromptType @@ -81,6 +82,11 @@ class WhisperPosEmbedType(enum.Enum): LEARNED = "learned" +def should_torch_compile_encoder(vllm_config: VllmConfig) -> bool: + """Callable to be passed to `@support_torch_compile`'s `enable_if` argument.""" + return vllm_config.compilation_config.compile_mm_encoder + + class WhisperAudioInputs(TensorSchema): """ Dimensions: @@ -98,6 +104,12 @@ class WhisperAudioInputs(TensorSchema): class WhisperEncoderAttention(MMEncoderAttention): """Multi-headed 
attention for Whisper encoder with 2D tensor support.""" + def __init__(self, num_heads: int, head_size: int, scale: float, num_kv_heads: int): + super().__init__(num_heads, head_size, scale, num_kv_heads) + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_kv_heads + def forward( self, query: torch.Tensor, @@ -181,9 +193,9 @@ class WhisperAttention(nn.Module): ) if attn_type == AttentionType.ENCODER: self.attn = WhisperEncoderAttention( - self.num_heads, - self.head_dim, - self.scaling, + num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, num_kv_heads=self.num_kv_heads, ) elif self.attn_type == AttentionType.ENCODER_DECODER: @@ -347,7 +359,13 @@ class WhisperMLP(nn.Module): class WhisperEncoderLayer(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__( + self, + *, + vllm_config: VllmConfig, + skip_overflow_clamp: bool = False, + prefix: str = "", + ): super().__init__() config = vllm_config.model_config.hf_config is_causal = getattr(config, "is_causal", False) @@ -355,6 +373,7 @@ class WhisperEncoderLayer(nn.Module): block_pool_size = getattr(config, "block_pool_size", 1) cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config + self._skip_overflow_clamp = skip_overflow_clamp self.embed_dim = config.d_model self.self_attn = WhisperAttention( @@ -390,7 +409,9 @@ class WhisperEncoderLayer(nn.Module): hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - hidden_states = cast_overflow_tensors(hidden_states) + # Not compatible with torch.compile + if not self._skip_overflow_clamp: + hidden_states = cast_overflow_tensors(hidden_states) return hidden_states @@ -454,6 +475,9 @@ class WhisperDecoderLayer(nn.Module): return hidden_states +@support_torch_compile( + dynamic_arg_dims={"input_features": 0}, enable_if=should_torch_compile_encoder +) class WhisperEncoder(nn.Module): def __init__( self, *, vllm_config: VllmConfig, 
prefix: str = "", init_in_fp32: bool = False @@ -479,7 +503,9 @@ class WhisperEncoder(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.encoder_layers, lambda prefix: WhisperEncoderLayer( - vllm_config=vllm_config, prefix=f"{prefix}.layers" + vllm_config=vllm_config, + prefix=f"{prefix}.layers", + skip_overflow_clamp=should_torch_compile_encoder(vllm_config), ), prefix=f"{prefix}.layers", ) @@ -510,38 +536,22 @@ class WhisperEncoder(nn.Module): self.embed_positions.weight.copy_( sinusoids(*self.embed_positions.weight.shape) ) + # TODO(review): confirm this raise is still reachable now that forward_conv no longer validates pos_embed_type raise ValueError(f"Unknown pos_embed_type: {self.pos_embed_type}") def forward_conv( self, input_features: torch.Tensor | list[torch.Tensor] ) -> torch.Tensor: - hidden_states = [] - input_is_batched = False - for features in input_features: - embeds = nn.functional.gelu(self.conv1(features)) - embeds = nn.functional.gelu(self.conv2(embeds)) + embeds = nn.functional.gelu(self.conv1(input_features)) + embeds = nn.functional.gelu(self.conv2(embeds)) + embeds = embeds.transpose(-1, -2) - if self.pos_embed_type in ( - WhisperPosEmbedType.SINUSOIDAL, - WhisperPosEmbedType.LEARNED, - ): - embeds = embeds.transpose(-1, -2) - embeds = ( - embeds + self.embed_positions.weight[: embeds.size(-2), :] - ).to(embeds.dtype) - elif self.pos_embed_type == WhisperPosEmbedType.NOPE: - embeds = embeds.transpose(-1, -2).to(embeds.dtype) - else: - raise ValueError(f"Unknown pos_embed_type: {self.pos_embed_type}") + if self.pos_embed_type == WhisperPosEmbedType.NOPE: + return embeds - hidden_states.append(embeds) - input_is_batched = embeds.ndim > 2 # Input to MHA must be B x T x D - if input_is_batched: - # Models using WhisperEncoder may handle batching internally.
- hidden_states = torch.cat(hidden_states) - else: - hidden_states = torch.stack(hidden_states, dim=0) - + hidden_states = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to( + embeds.dtype + ) return hidden_states def forward_layers(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -630,7 +640,7 @@ class WhisperModel(nn.Module): def get_encoder_outputs( self, - input_features: torch.Tensor | list[torch.Tensor] | None, + input_features: torch.Tensor | None, ) -> torch.Tensor | None: if input_features is None: return None @@ -951,7 +961,6 @@ class WhisperForConditionalGeneration( if input_features is not None: input_features = json_map_leaves(lambda x: x.to(self.dtype), input_features) - return WhisperAudioInputs(input_features=input_features) def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: