This commit is contained in:
Woosuk Kwon 2024-04-17 18:08:33 +00:00
parent e4377dd698
commit 0fb07c08d0

View File

@ -24,5 +24,5 @@ def flash_attn(
min(_DEFAULT_BLOCK_SIZES["block_q"], q.shape[1]),
min(_DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[1]),
min(_DEFAULT_BLOCK_SIZES["block_k"], k.shape[1]),
min(_DEFAULT_BLOCK_SIZES["block_b"], q.shape[0]))
min(_DEFAULT_BLOCK_SIZES["block_b"], q.shape[0])),
).transpose(0, 2, 1, 3)