Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 08:04:58 +08:00)
[Bugfix] Update attention interface in Whisper (#11784)
Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Parent: b278557935
Commit: 0f3f3c86ec
@@ -106,6 +106,7 @@ class WhisperAttention(nn.Module):
 cache_config=cache_config,
 quant_config=quant_config,
 prefix=f"{prefix}.attn",
+attn_type=self.attn_type,
 )

 def _init_qkv(
@@ -134,12 +135,7 @@ class WhisperAttention(nn.Module):
 qkv, _ = self.qkv_proj(hidden_states)
 q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

-attn_output = self.attn(q,
-k,
-v,
-kv_cache,
-attn_metadata,
-attn_type=self.attn_type)
+attn_output = self.attn(q, k, v, kv_cache, attn_metadata)

 output, _ = self.out_proj(attn_output)

@@ -164,6 +160,7 @@ class WhisperCrossAttention(WhisperAttention):
 cache_config=cache_config,
 quant_config=quant_config,
 prefix=prefix,
+attn_type=AttentionType.ENCODER_DECODER,
 )

 def _init_qkv(
@@ -207,12 +204,13 @@ class WhisperCrossAttention(WhisperAttention):
 else:
 k = v = None

-attn_output = self.attn(q,
-k,
-v,
-kv_cache,
-attn_metadata,
-attn_type=AttentionType.ENCODER_DECODER)
+attn_output = self.attn(
+q,
+k,
+v,
+kv_cache,
+attn_metadata,
+)

 output, _ = self.out_proj(attn_output)

@@ -734,4 +732,4 @@ class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
 loaded_weights = [(name, loaded_weight)
 for name, loaded_weight in weights]
 mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
-return loader.load_weights(loaded_weights, mapper=mapper)
+return loader.load_weights(loaded_weights, mapper=mapper)
Loading…
x
Reference in New Issue
Block a user