diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 8c5ad14c40758..cf44f7d52b0b5 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -316,9 +316,12 @@ class GPUModelRunner: def sample( self, - logits: torch.Tensor, + hidden_states: torch.Tensor, input_batch: InputBatch, ) -> SamplerOutput: + # TODO(woosuk): Support DP sampler + CUDA graphs. + sample_hidden_states = hidden_states[input_batch.logits_indices] + logits = self.model.compute_logits(sample_hidden_states, None) pos = input_batch.positions[input_batch.logits_indices] sampling_metadata = self.req_states.make_sampling_metadata( input_batch.idx_mapping_np, pos) @@ -369,11 +372,8 @@ class GPUModelRunner: positions=input_batch.positions[:num_tokens], ) - # Compute logits to sample next tokens. - sample_hidden_states = hidden_states[input_batch.logits_indices] - logits = self.model.compute_logits(sample_hidden_states, None) + sampler_output = self.sample(hidden_states, input_batch) - sampler_output = self.sample(logits, input_batch) sampled_token_ids_np, num_sampled_tokens = self.postprocess( sampler_output, input_batch) logprobs = None