[Bugfix] Initialize attention bias on the same device as Query/Key/Value (#13468)

This commit is contained in:
Junlin Zhou 2025-02-25 18:13:09 +08:00 committed by GitHub
parent 32c3b6bfd1
commit 75e9d49796
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -673,7 +673,9 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
             # Cross-attention mask is non-causal
             attn_bias = BlockDiagonalMask.from_seqlens(
-                attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
+                attn_metadata.seq_lens,
+                attn_metadata.encoder_seq_lens,
+                device=query.device)
# Encoder branch of encoder-decoder model uses
# attn_metadata.encoder_seq_lens
@@ -683,7 +685,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
             # Encoder self-attention mask is non-causal
             attn_bias = BlockDiagonalMask.from_seqlens(
-                attn_metadata.encoder_seq_lens)
+                attn_metadata.encoder_seq_lens, device=query.device)
# Self-attention block of encoder-only model just
# uses the seq_lens directly.
@@ -692,7 +694,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
             # Encoder self-attention mask is non-causal
             attn_bias = BlockDiagonalMask.from_seqlens(
-                attn_metadata.seq_lens)
+                attn_metadata.seq_lens, device=query.device)
# Self-attention block of decoder branch just
# uses the seq_lens directly
@@ -701,7 +703,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
             # Decoder self-attention mask is causal
             attn_bias = BlockDiagonalCausalMask.from_seqlens(
-                attn_metadata.seq_lens)
+                attn_metadata.seq_lens, device=query.device)
else:
raise ValueError("Unknown AttentionType: %s", attn_type)