diff --git a/vllm/v1/worker/gpu/sampler.py b/vllm/v1/worker/gpu/sampler.py index 55f98ca6bb6a..499e9d3b1538 100644 --- a/vllm/v1/worker/gpu/sampler.py +++ b/vllm/v1/worker/gpu/sampler.py @@ -3,10 +3,9 @@ from collections.abc import Callable import torch -import triton -import triton.language as tl from vllm.config.model import LogprobsMode +from vllm.triton_utils import tl, triton from vllm.v1.outputs import LogprobsTensors, SamplerOutput from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p from vllm.v1.worker.gpu.states import SamplingMetadata @@ -78,7 +77,10 @@ class Sampler: @triton.jit def _gumbel_sample_kernel( - sampled_ptr, + local_argmax_ptr, + local_argmax_stride, + local_max_ptr, + local_max_stride, logits_ptr, logits_stride, seeds_ptr, @@ -88,40 +90,21 @@ def _gumbel_sample_kernel( BLOCK_SIZE: tl.constexpr, ): req_idx = tl.program_id(0) + block_idx = tl.program_id(1) + block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = block < vocab_size + logits = tl.load( + logits_ptr + req_idx * logits_stride + block, + mask=mask, + other=float("-inf"), + ) + is_greedy = tl.load(is_greedy_ptr + req_idx) - - if is_greedy: - # Greedy sampling. Don't apply gumbel noise. - max_val = float("-inf") - max_idx = 0 - for i in range(0, vocab_size, BLOCK_SIZE): - block = i + tl.arange(0, BLOCK_SIZE) - mask = block < vocab_size - logits = tl.load( - logits_ptr + req_idx * logits_stride + block, - mask=mask, - other=float("-inf"), - ) - - idx = tl.argmax(logits, axis=0) - value = tl.max(logits, axis=0) - is_greater = value > max_val - max_val = tl.where(is_greater, value, max_val) - max_idx = tl.where(is_greater, i + idx, max_idx) - tl.store(sampled_ptr + req_idx, max_idx) - return - - # Random sampling. - # Calculate gumbel seed. - seed = tl.load(seeds_ptr + req_idx) - pos = tl.load(pos_ptr + req_idx) - gumbel_seed = tl.randint(seed, pos) - - max_val = float("-inf") - max_idx = 0 - for i in range(0, vocab_size, BLOCK_SIZE): - block = i + tl.arange(0, BLOCK_SIZE) - mask = block < vocab_size + if not is_greedy: + # Calculate the seed for gumbel noise. + seed = tl.load(seeds_ptr + req_idx) + pos = tl.load(pos_ptr + req_idx) + gumbel_seed = tl.randint(seed, pos) # Generate gumbel noise. r = tl.rand(gumbel_seed, block).to(tl.float64) @@ -129,16 +112,13 @@ def _gumbel_sample_kernel( gumbel_noise = gumbel_noise.to(tl.float32) # Apply gumbel noise. - logits = tl.load(logits_ptr + req_idx * logits_stride + block, mask=mask) logits = tl.where(mask, logits + gumbel_noise, float("-inf")) - # Argmax to get the sampled token. - idx = tl.argmax(logits, axis=0) - value = tl.max(logits, axis=0) - is_greater = value > max_val - max_val = tl.where(is_greater, value, max_val) - max_idx = tl.where(is_greater, i + idx, max_idx) - tl.store(sampled_ptr + req_idx, max_idx) + idx = tl.argmax(logits, axis=0) + token_id = block_idx * BLOCK_SIZE + idx + value = tl.max(logits, axis=0) + tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id) + tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value) def gumbel_sample( @@ -148,23 +128,36 @@ def gumbel_sample( pos: torch.Tensor, # [num_reqs] ) -> torch.Tensor: num_reqs, vocab_size = logits.shape - # NOTE(woosuk): Use int64 for later indexing. - sampled = torch.empty( + BLOCK_SIZE = 1024 + num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE) + local_argmax = torch.empty( num_reqs, + num_blocks, dtype=torch.int64, device=logits.device, ) - _gumbel_sample_kernel[(num_reqs,)]( - sampled, + local_max = torch.empty( + num_reqs, + num_blocks, + dtype=torch.float32, + device=logits.device, + ) + _gumbel_sample_kernel[(num_reqs, num_blocks)]( + local_argmax, + local_argmax.stride(0), + local_max, + local_max.stride(0), logits, logits.stride(0), seed, pos, is_greedy, vocab_size, - num_warps=8, - BLOCK_SIZE=16384, # type: ignore + BLOCK_SIZE=BLOCK_SIZE, ) + # NOTE(woosuk): Use int64 for later indexing. + max_block_idx = local_max.argmax(dim=-1, keepdim=True) + sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1) return sampled