Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
NickLucche 2025-12-19 12:18:58 +00:00
parent 87c0d52e87
commit 7c95fd8279

View File

@ -24,8 +24,8 @@ from vllm.attention.backends.abstract import (
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.attention.layers.cross_attention import CrossAttention from vllm.attention.layers.cross_attention import CrossAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
@ -106,9 +106,6 @@ class WhisperEncoderAttention(MMEncoderAttention):
def __init__(self, num_heads: int, head_size: int, scale: float, num_kv_heads: int): def __init__(self, num_heads: int, head_size: int, scale: float, num_kv_heads: int):
super().__init__(num_heads, head_size, scale, num_kv_heads) super().__init__(num_heads, head_size, scale, num_kv_heads)
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
def forward( def forward(
self, self,
@ -193,9 +190,9 @@ class WhisperAttention(nn.Module):
) )
if attn_type == AttentionType.ENCODER: if attn_type == AttentionType.ENCODER:
self.attn = WhisperEncoderAttention( self.attn = WhisperEncoderAttention(
num_heads=self.num_heads, self.num_heads,
head_size=self.head_dim, self.head_dim,
scale=self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
) )
elif self.attn_type == AttentionType.ENCODER_DECODER: elif self.attn_type == AttentionType.ENCODER_DECODER: