Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
NickLucche 2025-12-19 12:18:58 +00:00
parent 87c0d52e87
commit 7c95fd8279

View File

@@ -24,8 +24,8 @@ from vllm.attention.backends.abstract import (
from vllm.attention.layer import Attention
from vllm.attention.layers.cross_attention import CrossAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs.data import PromptType
@@ -106,9 +106,6 @@ class WhisperEncoderAttention(MMEncoderAttention):
def __init__(self, num_heads: int, head_size: int, scale: float, num_kv_heads: int):
    """Forward all constructor args to MMEncoderAttention.__init__.

    NOTE(review): this is the pre-commit state shown in the diff. The three
    attribute assignments below are the lines this commit deletes (hunk
    -106,9 +106,6) — presumably because the base-class constructor already
    stores them; confirm against MMEncoderAttention before reintroducing.
    """
    super().__init__(num_heads, head_size, scale, num_kv_heads)
    # Redundant re-assignment of constructor arguments (removed by this commit).
    self.num_heads = num_heads
    self.head_size = head_size
    self.num_kv_heads = num_kv_heads
def forward(
self,
@@ -193,9 +190,9 @@ class WhisperAttention(nn.Module):
)
if attn_type == AttentionType.ENCODER:
self.attn = WhisperEncoderAttention(
num_heads=self.num_heads,
head_size=self.head_dim,
scale=self.scaling,
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
)
elif self.attn_type == AttentionType.ENCODER_DECODER: