diff --git a/vllm/model_executor/models/jax/gemma.py b/vllm/model_executor/models/jax/gemma.py index bb85d9ed6ace3..161b609c417ef 100644 --- a/vllm/model_executor/models/jax/gemma.py +++ b/vllm/model_executor/models/jax/gemma.py @@ -311,7 +311,7 @@ class Transformer(nn.Module): ) kv_caches[i] = layer_cache x = self.final_norm(x) - + x = x.reshape(-1, x.shape[-1]) hidden_states = x[logits_indices] logits = self.embedder.decode(hidden_states) return logits, kv_caches