This commit is contained in:
Woosuk Kwon 2024-04-26 05:27:38 +00:00
parent d16a348477
commit 4ea41d01a9

View File

@@ -8,12 +8,13 @@ _DEFAULT_BLOCK_SIZES = {
"block_b": 2,
}
def flash_attn(
q: jax.Array, # [batch, seq_len, num_heads, head_size]
k: jax.Array, # [batch, seq_len, num_heads, head_size]
v: jax.Array, # [batch, seq_len, num_heads, head_size]
sm_scale: float,
) -> jax.Array: # [batch, seq_len, num_heads, head_size]
) -> jax.Array: # [batch, seq_len, num_heads, head_size]
return flash_attention(
q.transpose(0, 2, 1, 3),
k.transpose(0, 2, 1, 3),
@@ -21,8 +22,8 @@ def flash_attn(
causal=True,
sm_scale=sm_scale,
block_sizes=BlockSizes(
min(_DEFAULT_BLOCK_SIZES["block_q"], q.shape[1]),
min(_DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[1]),
min(_DEFAULT_BLOCK_SIZES["block_k"], k.shape[1]),
min(_DEFAULT_BLOCK_SIZES["block_b"], q.shape[0])),
min(_DEFAULT_BLOCK_SIZES["block_q"], q.shape[1]),
min(_DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[1]),
min(_DEFAULT_BLOCK_SIZES["block_k"], k.shape[1]),
min(_DEFAULT_BLOCK_SIZES["block_b"], q.shape[0])),
).transpose(0, 2, 1, 3)