Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-09-18 14:14:03 -07:00
parent bc6463ac97
commit aabfaa08cf
3 changed files with 34 additions and 42 deletions

View File

@@ -320,27 +320,15 @@ class GPUModelRunner:
logits: torch.Tensor,
input_batch: InputBatch,
) -> SamplerOutput:
pos = input_batch.positions[input_batch.logits_indices]
sampling_metadata = self.req_states.make_sampling_metadata(
input_batch.idx_mapping_np)
input_batch.idx_mapping_np, pos)
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
return sampler_output
def compute_prompt_logprobs(
self,
hidden_states: torch.Tensor,
input_batch: InputBatch,
):
idx_mapping_np = input_batch.idx_mapping_np
needs_prompt_logprobs = self.req_states.needs_prompt_logprobs[
idx_mapping_np]
if not np.any(needs_prompt_logprobs):
# Common case.
# No request in the batch needs prompt logprobs.
return None
def postprocess(
self,
sampler_output: SamplerOutput,
@@ -387,9 +375,6 @@ class GPUModelRunner:
logits = self.model.compute_logits(sample_hidden_states, None)
sampler_output = self.sample(logits, input_batch)
prompt_logprobs = self.compute_prompt_logprobs(hidden_states,
input_batch)
sampled_token_ids_np, num_sampled_tokens = self.postprocess(
sampler_output, input_batch)
req_id_to_index = {

View File

@@ -1,8 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import torch.nn as nn
import triton
@@ -47,8 +44,8 @@ class Sampler(nn.Module):
sampled = gumbel_sample(
probs,
sampling_metadata.temperature,
None, # seeds
None, # pos
sampling_metadata.seeds,
sampling_metadata.pos,
)
logprobs_tensors = None
@@ -163,22 +160,27 @@ def _apply_gumbel_kernel(
probs_ptr,
probs_stride,
seeds_ptr,
pos_ptr,
temp_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
EPSILON: tl.constexpr,
):
req_idx = tl.program_id(0)
seed = tl.load(seeds_ptr + req_idx)
temp = tl.load(temp_ptr + req_idx)
if temp < EPSILON:
# Greedy sampling. Don't apply gumbel noise.
return
seed = tl.load(seeds_ptr + req_idx)
pos = tl.load(pos_ptr + req_idx)
gumbel_seed = seed ^ (pos * 0x9E3779B9)
gumbel_seed = gumbel_seed & 0xFFFFFFFF
block_id = tl.program_id(1)
r_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
q = tl.rand(seed, r_offset)
q = tl.rand(gumbel_seed, r_offset)
# NOTE(woosuk): This logic makes sure q is not 0.
RMAX = 0.9999999403953552
@@ -201,34 +203,20 @@ def gumbel_sample(
# fp32[num_reqs]
temperature: torch.Tensor,
# int64[num_reqs]
seeds: Optional[torch.Tensor],
seeds: torch.Tensor,
# int64[num_reqs]
pos: Optional[torch.Tensor],
pos: torch.Tensor,
) -> torch.Tensor:
num_reqs = probs.shape[0]
vocab_size = probs.shape[1]
if seeds is not None:
# Per-request seed.
assert pos is not None
gumbel_seeds = seeds ^ (pos * 0x9E3779B9)
else:
# Global seed.
assert pos is None
seed_dtype = torch.int64
gumbel_seeds = torch.randint(
torch.iinfo(seed_dtype).min,
torch.iinfo(seed_dtype).max,
(num_reqs, ),
dtype=seed_dtype,
device=probs.device,
)
# Update the probs in-place.
BLOCK_SIZE = 8192
_apply_gumbel_kernel[(num_reqs, triton.cdiv(vocab_size, BLOCK_SIZE))](
probs,
probs.stride(0),
gumbel_seeds,
seeds,
pos,
temperature,
vocab_size,
BLOCK_SIZE,

View File

@@ -10,6 +10,9 @@ import torch
from vllm.sampling_params import SamplingParams
_NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max
@dataclass
class SamplingMetadata:
@@ -19,6 +22,9 @@ class SamplingMetadata:
top_p: Optional[torch.Tensor]
top_k: Optional[torch.Tensor]
seeds: torch.Tensor
pos: torch.Tensor
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs: Optional[int]
@@ -69,6 +75,7 @@ class RequestState:
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
# -1 means no logprobs are requested.
self.num_logprobs.fill(-1)
self.seeds = np.zeros(self.max_num_reqs, dtype=np.int64)
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
@@ -102,6 +109,12 @@ class RequestState:
top_k = self.vocab_size
self.top_k[req_idx] = top_k
if sampling_params.seed is not None:
seed = sampling_params.seed
else:
seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
self.seeds[req_idx] = seed
if sampling_params.logprobs is not None:
num_logprobs = sampling_params.logprobs
else:
@@ -133,6 +146,7 @@ class RequestState:
def make_sampling_metadata(
self,
idx_mapping: np.ndarray,
pos: torch.Tensor,
) -> SamplingMetadata:
temperature = self.temperature[idx_mapping]
temperature = self._copy_np_to_gpu(temperature)
@@ -145,6 +159,9 @@ class RequestState:
no_top_k = np.all(top_k == self.vocab_size)
top_k = self._copy_np_to_gpu(top_k) if not no_top_k else None
seeds = self.seeds[idx_mapping]
seeds = self._copy_np_to_gpu(seeds)
num_logprobs = self.num_logprobs[idx_mapping]
max_num_logprobs = np.max(num_logprobs)
if max_num_logprobs == -1:
@@ -154,6 +171,8 @@ class RequestState:
temperature=temperature,
top_p=top_p,
top_k=top_k,
seeds=seeds,
pos=pos,
max_num_logprobs=max_num_logprobs,
)