mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 07:25:01 +08:00
[Bugfix] Update attention interface in Whisper (#11784)
Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
parent
b278557935
commit
0f3f3c86ec
@ -106,6 +106,7 @@ class WhisperAttention(nn.Module):
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
attn_type=self.attn_type,
|
||||
)
|
||||
|
||||
def _init_qkv(
|
||||
@ -134,12 +135,7 @@ class WhisperAttention(nn.Module):
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
attn_output = self.attn(q,
|
||||
k,
|
||||
v,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
attn_type=self.attn_type)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
|
||||
output, _ = self.out_proj(attn_output)
|
||||
|
||||
@ -164,6 +160,7 @@ class WhisperCrossAttention(WhisperAttention):
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
attn_type=AttentionType.ENCODER_DECODER,
|
||||
)
|
||||
|
||||
def _init_qkv(
|
||||
@ -207,12 +204,13 @@ class WhisperCrossAttention(WhisperAttention):
|
||||
else:
|
||||
k = v = None
|
||||
|
||||
attn_output = self.attn(q,
|
||||
attn_output = self.attn(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
attn_type=AttentionType.ENCODER_DECODER)
|
||||
)
|
||||
|
||||
output, _ = self.out_proj(attn_output)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user