diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py index 7269d19183bf2..2e0b4efebfdb1 100644 --- a/tests/kernels/attention/test_attention.py +++ b/tests/kernels/attention/test_attention.py @@ -450,7 +450,8 @@ def test_multi_query_kv_attention( start += seq_len # xformers.AttentionBias to Tensor for use in reference impl. alibi_bias = [ - b.materialize(b.shape, device=device).squeeze() for b in attn_bias + b.materialize((1, num_query_heads, i, i), device=device).squeeze() + for b, i in zip(attn_bias, seq_lens) ] else: attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)