mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 07:05:01 +08:00
[Bugfix] Update attention interface in Whisper (#11784)
Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
parent
b278557935
commit
0f3f3c86ec
@ -106,6 +106,7 @@ class WhisperAttention(nn.Module):
|
|||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.attn",
|
prefix=f"{prefix}.attn",
|
||||||
|
attn_type=self.attn_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_qkv(
|
def _init_qkv(
|
||||||
@ -134,12 +135,7 @@ class WhisperAttention(nn.Module):
|
|||||||
qkv, _ = self.qkv_proj(hidden_states)
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
|
||||||
attn_output = self.attn(q,
|
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||||
k,
|
|
||||||
v,
|
|
||||||
kv_cache,
|
|
||||||
attn_metadata,
|
|
||||||
attn_type=self.attn_type)
|
|
||||||
|
|
||||||
output, _ = self.out_proj(attn_output)
|
output, _ = self.out_proj(attn_output)
|
||||||
|
|
||||||
@ -164,6 +160,7 @@ class WhisperCrossAttention(WhisperAttention):
|
|||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
|
attn_type=AttentionType.ENCODER_DECODER,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_qkv(
|
def _init_qkv(
|
||||||
@ -207,12 +204,13 @@ class WhisperCrossAttention(WhisperAttention):
|
|||||||
else:
|
else:
|
||||||
k = v = None
|
k = v = None
|
||||||
|
|
||||||
attn_output = self.attn(q,
|
attn_output = self.attn(
|
||||||
|
q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
attn_type=AttentionType.ENCODER_DECODER)
|
)
|
||||||
|
|
||||||
output, _ = self.out_proj(attn_output)
|
output, _ = self.out_proj(attn_output)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user