mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-11 15:29:08 +08:00
Add attn_mask
This commit is contained in:
parent
0fb07c08d0
commit
4880de35d2
@ -163,20 +163,25 @@ class Attention(nn.Module):
|
|||||||
value_proj = jnp.repeat(value_proj, self.num_heads, axis=-2)
|
value_proj = jnp.repeat(value_proj, self.num_heads, axis=-2)
|
||||||
key_proj = jnp.repeat(key_proj, self.num_heads, axis=-2)
|
key_proj = jnp.repeat(key_proj, self.num_heads, axis=-2)
|
||||||
|
|
||||||
# output = flash_attn(
|
if False:
|
||||||
# query_proj,
|
# FIXME(woosuk)
|
||||||
# key_proj,
|
output = flash_attn(
|
||||||
# value_proj,
|
query_proj,
|
||||||
# self.sm_scale,
|
key_proj,
|
||||||
# )
|
value_proj,
|
||||||
query_scaled = query_proj * self.sm_scale
|
self.sm_scale,
|
||||||
logits = jnp.einsum('BTNH,BSNH->BTNS', query_scaled, key_proj)
|
)
|
||||||
padded_logits = logits
|
else:
|
||||||
# padded_logits = jnp.where(
|
seq_len = query_proj.shape[1]
|
||||||
# (jnp.expand_dims(attn_mask, -2)), logits, K_MASK
|
attn_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))
|
||||||
# )
|
|
||||||
probs = jax.nn.softmax(padded_logits, axis=-1).astype(key_proj.dtype)
|
query_scaled = query_proj * self.sm_scale
|
||||||
output = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
|
logits = jnp.einsum('BTNH,BSNH->BTNS', query_scaled, key_proj)
|
||||||
|
masked_logits = jnp.where(
|
||||||
|
(jnp.expand_dims(attn_mask, -2)), logits, K_MASK
|
||||||
|
)
|
||||||
|
probs = jax.nn.softmax(masked_logits, axis=-1).astype(key_proj.dtype)
|
||||||
|
output = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
|
||||||
else:
|
else:
|
||||||
# Decode attention.
|
# Decode attention.
|
||||||
output = paged_attn(
|
output = paged_attn(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user