[Model Runner V2] Optimize Gumbel Sampling Kernel (#29210)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-11-21 15:52:28 -08:00 committed by GitHub
parent c6fa3895e9
commit e9af6ba62a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,10 +3,9 @@
from collections.abc import Callable
import torch
import triton
import triton.language as tl
from vllm.config.model import LogprobsMode
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.states import SamplingMetadata
@ -78,7 +77,10 @@ class Sampler:
@triton.jit
def _gumbel_sample_kernel(
sampled_ptr,
local_argmax_ptr,
local_argmax_stride,
local_max_ptr,
local_max_stride,
logits_ptr,
logits_stride,
seeds_ptr,
@ -88,40 +90,21 @@ def _gumbel_sample_kernel(
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
is_greedy = tl.load(is_greedy_ptr + req_idx)
if is_greedy:
# Greedy sampling. Don't apply gumbel noise.
max_val = float("-inf")
max_idx = 0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
idx = tl.argmax(logits, axis=0)
value = tl.max(logits, axis=0)
is_greater = value > max_val
max_val = tl.where(is_greater, value, max_val)
max_idx = tl.where(is_greater, i + idx, max_idx)
tl.store(sampled_ptr + req_idx, max_idx)
return
# Random sampling.
# Calculate gumbel seed.
seed = tl.load(seeds_ptr + req_idx)
pos = tl.load(pos_ptr + req_idx)
gumbel_seed = tl.randint(seed, pos)
max_val = float("-inf")
max_idx = 0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
if not is_greedy:
# Calculate the seed for gumbel noise.
seed = tl.load(seeds_ptr + req_idx)
pos = tl.load(pos_ptr + req_idx)
gumbel_seed = tl.randint(seed, pos)
# Generate gumbel noise.
r = tl.rand(gumbel_seed, block).to(tl.float64)
@ -129,16 +112,13 @@ def _gumbel_sample_kernel(
gumbel_noise = gumbel_noise.to(tl.float32)
# Apply gumbel noise.
logits = tl.load(logits_ptr + req_idx * logits_stride + block, mask=mask)
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
# Argmax to get the sampled token.
idx = tl.argmax(logits, axis=0)
value = tl.max(logits, axis=0)
is_greater = value > max_val
max_val = tl.where(is_greater, value, max_val)
max_idx = tl.where(is_greater, i + idx, max_idx)
tl.store(sampled_ptr + req_idx, max_idx)
idx = tl.argmax(logits, axis=0)
token_id = block_idx * BLOCK_SIZE + idx
value = tl.max(logits, axis=0)
tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id)
tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)
def gumbel_sample(
@ -148,23 +128,36 @@ def gumbel_sample(
pos: torch.Tensor, # [num_reqs]
) -> torch.Tensor:
num_reqs, vocab_size = logits.shape
# NOTE(woosuk): Use int64 for later indexing.
sampled = torch.empty(
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
local_argmax = torch.empty(
num_reqs,
num_blocks,
dtype=torch.int64,
device=logits.device,
)
_gumbel_sample_kernel[(num_reqs,)](
sampled,
local_max = torch.empty(
num_reqs,
num_blocks,
dtype=torch.float32,
device=logits.device,
)
_gumbel_sample_kernel[(num_reqs, num_blocks)](
local_argmax,
local_argmax.stride(0),
local_max,
local_max.stride(0),
logits,
logits.stride(0),
seed,
pos,
is_greedy,
vocab_size,
num_warps=8,
BLOCK_SIZE=16384, # type: ignore
BLOCK_SIZE=BLOCK_SIZE,
)
# NOTE(woosuk): Use int64 for later indexing.
max_block_idx = local_max.argmax(dim=-1, keepdim=True)
sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
return sampled