From aabfaa08cf28333116ab167842d0d7887d8077ea Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 18 Sep 2025 14:14:03 -0700 Subject: [PATCH] fix Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu/model_runner.py | 19 ++------------- vllm/v1/worker/gpu/sampler.py | 38 ++++++++++-------------------- vllm/v1/worker/gpu/states.py | 19 +++++++++++++++ 3 files changed, 34 insertions(+), 42 deletions(-) diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 22015d2680b30..7435e3eceb69f 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -320,27 +320,15 @@ class GPUModelRunner: logits: torch.Tensor, input_batch: InputBatch, ) -> SamplerOutput: + pos = input_batch.positions[input_batch.logits_indices] sampling_metadata = self.req_states.make_sampling_metadata( - input_batch.idx_mapping_np) + input_batch.idx_mapping_np, pos) sampler_output = self.sampler( logits=logits, sampling_metadata=sampling_metadata, ) return sampler_output - def compute_prompt_logprobs( - self, - hidden_states: torch.Tensor, - input_batch: InputBatch, - ): - idx_mapping_np = input_batch.idx_mapping_np - needs_prompt_logprobs = self.req_states.needs_prompt_logprobs[ - idx_mapping_np] - if not np.any(needs_prompt_logprobs): - # Common case. - # No request in the batch needs prompt logprobs. - return None - def postprocess( self, sampler_output: SamplerOutput, @@ -387,9 +375,6 @@ class GPUModelRunner: logits = self.model.compute_logits(sample_hidden_states, None) sampler_output = self.sample(logits, input_batch) - prompt_logprobs = self.compute_prompt_logprobs(hidden_states, - input_batch) - sampled_token_ids_np, num_sampled_tokens = self.postprocess( sampler_output, input_batch) req_id_to_index = { diff --git a/vllm/v1/worker/gpu/sampler.py b/vllm/v1/worker/gpu/sampler.py index d3b9e1f780925..095aa233a4d2c 100644 --- a/vllm/v1/worker/gpu/sampler.py +++ b/vllm/v1/worker/gpu/sampler.py @@ -1,8 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Optional - import torch import torch.nn as nn import triton @@ -47,8 +44,8 @@ class Sampler(nn.Module): sampled = gumbel_sample( probs, sampling_metadata.temperature, - None, # seeds - None, # pos + sampling_metadata.seeds, + sampling_metadata.pos, ) logprobs_tensors = None @@ -163,22 +160,27 @@ def _apply_gumbel_kernel( probs_ptr, probs_stride, seeds_ptr, + pos_ptr, temp_ptr, vocab_size, BLOCK_SIZE: tl.constexpr, EPSILON: tl.constexpr, ): req_idx = tl.program_id(0) - seed = tl.load(seeds_ptr + req_idx) temp = tl.load(temp_ptr + req_idx) if temp < EPSILON: # Greedy sampling. Don't apply gumbel noise. return + seed = tl.load(seeds_ptr + req_idx) + pos = tl.load(pos_ptr + req_idx) + gumbel_seed = seed ^ (pos * 0x9E3779B9) + gumbel_seed = gumbel_seed & 0xFFFFFFFF + block_id = tl.program_id(1) r_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - q = tl.rand(seed, r_offset) + q = tl.rand(gumbel_seed, r_offset) # NOTE(woosuk): This logic makes sure q is not 0. RMAX = 0.9999999403953552 @@ -201,34 +203,20 @@ def gumbel_sample( # fp32[num_reqs] temperature: torch.Tensor, # int64[num_reqs] - seeds: Optional[torch.Tensor], + seeds: torch.Tensor, # int64[num_reqs] - pos: Optional[torch.Tensor], + pos: torch.Tensor, ) -> torch.Tensor: num_reqs = probs.shape[0] vocab_size = probs.shape[1] - if seeds is not None: - # Per-request seed. - assert pos is not None - gumbel_seeds = seeds ^ (pos * 0x9E3779B9) - else: - # Global seed. 
- assert pos is None - seed_dtype = torch.int64 - gumbel_seeds = torch.randint( - torch.iinfo(seed_dtype).min, - torch.iinfo(seed_dtype).max, - (num_reqs, ), - dtype=seed_dtype, - device=probs.device, - ) # Update the probs in-place. BLOCK_SIZE = 8192 _apply_gumbel_kernel[(num_reqs, triton.cdiv(vocab_size, BLOCK_SIZE))]( probs, probs.stride(0), - gumbel_seeds, + seeds, + pos, temperature, vocab_size, BLOCK_SIZE, diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py index f5e4dea82c27a..f154cff0cf769 100644 --- a/vllm/v1/worker/gpu/states.py +++ b/vllm/v1/worker/gpu/states.py @@ -10,6 +10,9 @@ import torch from vllm.sampling_params import SamplingParams +_NP_INT64_MIN = np.iinfo(np.int64).min +_NP_INT64_MAX = np.iinfo(np.int64).max + @dataclass class SamplingMetadata: @@ -19,6 +22,9 @@ class SamplingMetadata: top_p: Optional[torch.Tensor] top_k: Optional[torch.Tensor] + seeds: torch.Tensor + pos: torch.Tensor + # None means no logprobs, 0 means sampled token logprobs only max_num_logprobs: Optional[int] @@ -69,6 +75,7 @@ class RequestState: self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32) # -1 means no logprobs are requested. self.num_logprobs.fill(-1) + self.seeds = np.zeros(self.max_num_reqs, dtype=np.int64) self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool) @@ -102,6 +109,12 @@ class RequestState: top_k = self.vocab_size self.top_k[req_idx] = top_k + if sampling_params.seed is not None: + seed = sampling_params.seed + else: + seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX) + self.seeds[req_idx] = seed + if sampling_params.logprobs is not None: num_logprobs = sampling_params.logprobs else: @@ -133,6 +146,7 @@ class RequestState: def make_sampling_metadata( self, idx_mapping: np.ndarray, + pos: torch.Tensor, ) -> SamplingMetadata: temperature = self.temperature[idx_mapping] temperature = self._copy_np_to_gpu(temperature) @@ -145,6 +159,9 @@ class RequestState: no_top_k = np.all(top_k == self.vocab_size) top_k = self._copy_np_to_gpu(top_k) if not no_top_k else None + seeds = self.seeds[idx_mapping] + seeds = self._copy_np_to_gpu(seeds) + num_logprobs = self.num_logprobs[idx_mapping] max_num_logprobs = np.max(num_logprobs) if max_num_logprobs == -1: @@ -154,6 +171,8 @@ class RequestState: temperature=temperature, top_p=top_p, top_k=top_k, + seeds=seeds, + pos=pos, max_num_logprobs=max_num_logprobs, )
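
Note (reviewer sketch, not part of the patch): the change replaces the optional global-seed path in gumbel_sample with a mandatory per-request seed, and mixes that seed with the position of the token being sampled inside the Triton kernel (seed ^ (pos * 0x9E3779B9), masked to 32 bits), so a fixed seed still produces fresh noise at every decoding step. A minimal plain-PyTorch equivalent is sketched below, assuming probs is already temperature-scaled; the helper name gumbel_sample_ref, the eps threshold, and the clamping of the uniform draws are illustrative assumptions, and the sketch is not bit-identical to the kernel.

    import torch

    def gumbel_sample_ref(
        probs: torch.Tensor,        # fp32[num_reqs, vocab_size], temperature already applied
        temperature: torch.Tensor,  # fp32[num_reqs], used only to detect greedy requests
        seeds: torch.Tensor,        # int64[num_reqs], per-request seed
        pos: torch.Tensor,          # int64[num_reqs], position of the token being sampled
        eps: float = 1e-5,
    ) -> torch.Tensor:
        num_reqs, vocab_size = probs.shape
        out = torch.empty(num_reqs, dtype=torch.int64, device=probs.device)
        for i in range(num_reqs):
            if temperature[i].item() < eps:
                # Greedy request: skip the Gumbel noise and take the argmax.
                out[i] = probs[i].argmax()
                continue
            # Mix the per-request seed with the token position so the same seed
            # yields a different noise stream at each decoding step.
            mixed = (seeds[i].item() ^ (pos[i].item() * 0x9E3779B9)) & 0xFFFFFFFF
            gen = torch.Generator(device=probs.device).manual_seed(mixed)
            q = torch.rand(vocab_size, generator=gen, device=probs.device)
            # Keep q strictly inside (0, 1) so -log(q) is finite and positive.
            q = q.clamp(min=torch.finfo(torch.float32).tiny, max=0.9999999403953552)
            # Gumbel-max trick: argmax(p / Exponential(1)) picks index i with
            # probability proportional to p[i].
            out[i] = (probs[i] / -torch.log(q)).argmax()
        return out

The XOR with 0x9E3779B9 (the 32-bit golden-ratio constant) is a cheap way to decorrelate the noise streams of consecutive positions while keeping sampling reproducible for requests that set sampling_params.seed; requests without an explicit seed fall back to the randomly drawn per-request seed stored in RequestState.seeds.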