[Bugfix] Update attention interface in Whisper (#11784)

Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Roger Wang 2025-01-06 20:36:24 -08:00 committed by GitHub
parent b278557935
commit 0f3f3c86ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -106,6 +106,7 @@ class WhisperAttention(nn.Module):
 cache_config=cache_config,
 quant_config=quant_config,
 prefix=f"{prefix}.attn",
+attn_type=self.attn_type,
 )
 def _init_qkv(
@@ -134,12 +135,7 @@ class WhisperAttention(nn.Module):
 qkv, _ = self.qkv_proj(hidden_states)
 q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-attn_output = self.attn(q,
-k,
-v,
-kv_cache,
-attn_metadata,
-attn_type=self.attn_type)
+attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
 output, _ = self.out_proj(attn_output)
@@ -164,6 +160,7 @@ class WhisperCrossAttention(WhisperAttention):
 cache_config=cache_config,
 quant_config=quant_config,
 prefix=prefix,
+attn_type=AttentionType.ENCODER_DECODER,
 )
 def _init_qkv(
@@ -207,12 +204,13 @@ class WhisperCrossAttention(WhisperAttention):
 else:
 k = v = None
-attn_output = self.attn(q,
-k,
-v,
-kv_cache,
-attn_metadata,
-attn_type=AttentionType.ENCODER_DECODER)
+attn_output = self.attn(
+q,
+k,
+v,
+kv_cache,
+attn_metadata,
+)
 output, _ = self.out_proj(attn_output)
@@ -734,4 +732,4 @@ class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
 loaded_weights = [(name, loaded_weight)
 for name, loaded_weight in weights]
 mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
-return loader.load_weights(loaded_weights, mapper=mapper)
+return loader.load_weights(loaded_weights, mapper=mapper)