mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-27 14:13:46 +08:00
Add attn_mask
This commit is contained in:
parent
0fb07c08d0
commit
4880de35d2
@ -163,20 +163,25 @@ class Attention(nn.Module):
|
||||
value_proj = jnp.repeat(value_proj, self.num_heads, axis=-2)
|
||||
key_proj = jnp.repeat(key_proj, self.num_heads, axis=-2)
|
||||
|
||||
# output = flash_attn(
|
||||
# query_proj,
|
||||
# key_proj,
|
||||
# value_proj,
|
||||
# self.sm_scale,
|
||||
# )
|
||||
query_scaled = query_proj * self.sm_scale
|
||||
logits = jnp.einsum('BTNH,BSNH->BTNS', query_scaled, key_proj)
|
||||
padded_logits = logits
|
||||
# padded_logits = jnp.where(
|
||||
# (jnp.expand_dims(attn_mask, -2)), logits, K_MASK
|
||||
# )
|
||||
probs = jax.nn.softmax(padded_logits, axis=-1).astype(key_proj.dtype)
|
||||
output = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
|
||||
if False:
|
||||
# FIXME(woosuk)
|
||||
output = flash_attn(
|
||||
query_proj,
|
||||
key_proj,
|
||||
value_proj,
|
||||
self.sm_scale,
|
||||
)
|
||||
else:
|
||||
seq_len = query_proj.shape[1]
|
||||
attn_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))
|
||||
|
||||
query_scaled = query_proj * self.sm_scale
|
||||
logits = jnp.einsum('BTNH,BSNH->BTNS', query_scaled, key_proj)
|
||||
masked_logits = jnp.where(
|
||||
(jnp.expand_dims(attn_mask, -2)), logits, K_MASK
|
||||
)
|
||||
probs = jax.nn.softmax(masked_logits, axis=-1).astype(key_proj.dtype)
|
||||
output = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
|
||||
else:
|
||||
# Decode attention.
|
||||
output = paged_attn(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user