mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 22:04:58 +08:00
[Bugfix] Initialize attention bias on the same device as Query/Key/Value (#13468)
This commit is contained in:
parent
32c3b6bfd1
commit
75e9d49796
@ -673,7 +673,9 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
|
|
||||||
# Cross-attention mask is non-causal
|
# Cross-attention mask is non-causal
|
||||||
attn_bias = BlockDiagonalMask.from_seqlens(
|
attn_bias = BlockDiagonalMask.from_seqlens(
|
||||||
attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
|
attn_metadata.seq_lens,
|
||||||
|
attn_metadata.encoder_seq_lens,
|
||||||
|
device=query.device)
|
||||||
|
|
||||||
# Encoder branch of encoder-decoder model uses
|
# Encoder branch of encoder-decoder model uses
|
||||||
# attn_metadata.encoder_seq_lens
|
# attn_metadata.encoder_seq_lens
|
||||||
@ -683,7 +685,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
|
|
||||||
# Encoder self-attention mask is non-causal
|
# Encoder self-attention mask is non-causal
|
||||||
attn_bias = BlockDiagonalMask.from_seqlens(
|
attn_bias = BlockDiagonalMask.from_seqlens(
|
||||||
attn_metadata.encoder_seq_lens)
|
attn_metadata.encoder_seq_lens, device=query.device)
|
||||||
|
|
||||||
# Self-attention block of encoder-only model just
|
# Self-attention block of encoder-only model just
|
||||||
# uses the seq_lens directly.
|
# uses the seq_lens directly.
|
||||||
@ -692,7 +694,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
|
|
||||||
# Encoder self-attention mask is non-causal
|
# Encoder self-attention mask is non-causal
|
||||||
attn_bias = BlockDiagonalMask.from_seqlens(
|
attn_bias = BlockDiagonalMask.from_seqlens(
|
||||||
attn_metadata.seq_lens)
|
attn_metadata.seq_lens, device=query.device)
|
||||||
|
|
||||||
# Self-attention block of decoder branch just
|
# Self-attention block of decoder branch just
|
||||||
# uses the seq_lens directly
|
# uses the seq_lens directly
|
||||||
@ -701,7 +703,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
|
|
||||||
# Decoder self-attention mask is causal
|
# Decoder self-attention mask is causal
|
||||||
attn_bias = BlockDiagonalCausalMask.from_seqlens(
|
attn_bias = BlockDiagonalCausalMask.from_seqlens(
|
||||||
attn_metadata.seq_lens)
|
attn_metadata.seq_lens, device=query.device)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown AttentionType: %s", attn_type)
|
raise ValueError("Unknown AttentionType: %s", attn_type)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user