mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-24 04:57:52 +08:00
yapf
This commit is contained in:
parent
d16a348477
commit
4ea41d01a9
@ -8,12 +8,13 @@ _DEFAULT_BLOCK_SIZES = {
|
|||||||
"block_b": 2,
|
"block_b": 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def flash_attn(
|
def flash_attn(
|
||||||
q: jax.Array, # [batch, seq_len, num_heads, head_size]
|
q: jax.Array, # [batch, seq_len, num_heads, head_size]
|
||||||
k: jax.Array, # [batch, seq_len, num_heads, head_size]
|
k: jax.Array, # [batch, seq_len, num_heads, head_size]
|
||||||
v: jax.Array, # [batch, seq_len, num_heads, head_size]
|
v: jax.Array, # [batch, seq_len, num_heads, head_size]
|
||||||
sm_scale: float,
|
sm_scale: float,
|
||||||
) -> jax.Array: # [batch, seq_len, num_heads, head_size]
|
) -> jax.Array: # [batch, seq_len, num_heads, head_size]
|
||||||
return flash_attention(
|
return flash_attention(
|
||||||
q.transpose(0, 2, 1, 3),
|
q.transpose(0, 2, 1, 3),
|
||||||
k.transpose(0, 2, 1, 3),
|
k.transpose(0, 2, 1, 3),
|
||||||
@ -21,8 +22,8 @@ def flash_attn(
|
|||||||
causal=True,
|
causal=True,
|
||||||
sm_scale=sm_scale,
|
sm_scale=sm_scale,
|
||||||
block_sizes=BlockSizes(
|
block_sizes=BlockSizes(
|
||||||
min(_DEFAULT_BLOCK_SIZES["block_q"], q.shape[1]),
|
min(_DEFAULT_BLOCK_SIZES["block_q"], q.shape[1]),
|
||||||
min(_DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[1]),
|
min(_DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[1]),
|
||||||
min(_DEFAULT_BLOCK_SIZES["block_k"], k.shape[1]),
|
min(_DEFAULT_BLOCK_SIZES["block_k"], k.shape[1]),
|
||||||
min(_DEFAULT_BLOCK_SIZES["block_b"], q.shape[0])),
|
min(_DEFAULT_BLOCK_SIZES["block_b"], q.shape[0])),
|
||||||
).transpose(0, 2, 1, 3)
|
).transpose(0, 2, 1, 3)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user