Fix logit indices

This commit is contained in:
Woosuk Kwon 2024-04-17 18:01:43 +00:00
parent cedb67028a
commit 8888d1c474

View File

@ -311,7 +311,7 @@ class Transformer(nn.Module):
)
kv_caches[i] = layer_cache
x = self.final_norm(x)
x = x.reshape(-1, x.shape[-1])
hidden_states = x[logits_indices]
logits = self.embedder.decode(hidden_states)
return logits, kv_caches