From 4880de35d2455ce11b4a0f245e3d62d4095bc8f9 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 17 Apr 2024 18:12:20 +0000 Subject: [PATCH] Add attn_mask --- vllm/model_executor/models/jax/gemma.py | 33 ++++++++++++++----------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/models/jax/gemma.py b/vllm/model_executor/models/jax/gemma.py index 161b609c417ef..a244aaa9c94ff 100644 --- a/vllm/model_executor/models/jax/gemma.py +++ b/vllm/model_executor/models/jax/gemma.py @@ -163,20 +163,25 @@ class Attention(nn.Module): value_proj = jnp.repeat(value_proj, self.num_heads, axis=-2) key_proj = jnp.repeat(key_proj, self.num_heads, axis=-2) - # output = flash_attn( - # query_proj, - # key_proj, - # value_proj, - # self.sm_scale, - # ) - query_scaled = query_proj * self.sm_scale - logits = jnp.einsum('BTNH,BSNH->BTNS', query_scaled, key_proj) - padded_logits = logits - # padded_logits = jnp.where( - # (jnp.expand_dims(attn_mask, -2)), logits, K_MASK - # ) - probs = jax.nn.softmax(padded_logits, axis=-1).astype(key_proj.dtype) - output = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj) + if False: + # FIXME(woosuk) + output = flash_attn( + query_proj, + key_proj, + value_proj, + self.sm_scale, + ) + else: + seq_len = query_proj.shape[1] + attn_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_)) + + query_scaled = query_proj * self.sm_scale + logits = jnp.einsum('BTNH,BSNH->BTNS', query_scaled, key_proj) + masked_logits = jnp.where( + (jnp.expand_dims(attn_mask, -2)), logits, K_MASK + ) + probs = jax.nn.softmax(masked_logits, axis=-1).astype(key_proj.dtype) + output = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj) else: # Decode attention. output = paged_attn(