From 8888d1c4741a383c3a0beb307d45f74ed6e5403c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 17 Apr 2024 18:01:43 +0000 Subject: [PATCH] Fix logit indices --- vllm/model_executor/models/jax/gemma.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/jax/gemma.py b/vllm/model_executor/models/jax/gemma.py index bb85d9ed6ace3..161b609c417ef 100644 --- a/vllm/model_executor/models/jax/gemma.py +++ b/vllm/model_executor/models/jax/gemma.py @@ -311,7 +311,7 @@ class Transformer(nn.Module): ) kv_caches[i] = layer_cache x = self.final_norm(x) - + x = x.reshape(-1, x.shape[-1]) hidden_states = x[logits_indices] logits = self.embedder.decode(hidden_states) return logits, kv_caches