diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index c485d3779d9a6..bf330c7770d12 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -876,7 +876,7 @@ class JambaForCausalLM(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits