[Model Runner V2] Optimize Gumbel Sampling Kernel (#29210)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Parent: c6fa3895e9
Commit: e9af6ba62a
@@ -3,10 +3,9 @@
 from collections.abc import Callable
 
 import torch
-import triton
-import triton.language as tl
 
 from vllm.config.model import LogprobsMode
+from vllm.triton_utils import tl, triton
 from vllm.v1.outputs import LogprobsTensors, SamplerOutput
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 from vllm.v1.worker.gpu.states import SamplingMetadata
@@ -78,7 +77,10 @@ class Sampler:
 
 @triton.jit
 def _gumbel_sample_kernel(
-    sampled_ptr,
+    local_argmax_ptr,
+    local_argmax_stride,
+    local_max_ptr,
+    local_max_stride,
     logits_ptr,
     logits_stride,
     seeds_ptr,
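The hunks above and below change the kernel's contract. Previously one Triton program per request looped over the whole vocabulary and wrote the final token id to sampled_ptr; now the kernel runs on a 2D grid over (request, vocabulary block), and each program writes only a block-local argmax and block-local max into local_argmax / local_max buffers of shape [num_reqs, num_blocks]. The full row argmax can be recovered exactly from those partials, because the global maximum lives in the block with the largest block maximum. A minimal PyTorch sketch of that two-phase reduction (the shapes, names, and block size here are illustrative, not the module's API):

import torch

# Illustrative sizes: 2 requests, vocab of 10, block size 4.
logits = torch.randn(2, 10)
BLOCK = 4
num_blocks = (logits.shape[1] + BLOCK - 1) // BLOCK

# Pad the tail with -inf, as the kernel's masked load does.
pad = num_blocks * BLOCK - logits.shape[1]
padded = torch.nn.functional.pad(logits, (0, pad), value=float("-inf"))
blocks = padded.view(2, num_blocks, BLOCK)

# Phase 1: per-block max and argmax (what each Triton program stores).
local_max, local_idx = blocks.max(dim=-1)
local_argmax = local_idx + torch.arange(num_blocks) * BLOCK

# Phase 2: pick the winning block per request (done on the host).
best_block = local_max.argmax(dim=-1, keepdim=True)
sampled = local_argmax.gather(-1, best_block).view(-1)

assert torch.equal(sampled, logits.argmax(dim=-1))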
@@ -88,40 +90,21 @@ def _gumbel_sample_kernel(
     BLOCK_SIZE: tl.constexpr,
 ):
     req_idx = tl.program_id(0)
+    block_idx = tl.program_id(1)
+    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = block < vocab_size
+    logits = tl.load(
+        logits_ptr + req_idx * logits_stride + block,
+        mask=mask,
+        other=float("-inf"),
+    )
+
     is_greedy = tl.load(is_greedy_ptr + req_idx)
-
-    if is_greedy:
-        # Greedy sampling. Don't apply gumbel noise.
-        max_val = float("-inf")
-        max_idx = 0
-        for i in range(0, vocab_size, BLOCK_SIZE):
-            block = i + tl.arange(0, BLOCK_SIZE)
-            mask = block < vocab_size
-            logits = tl.load(
-                logits_ptr + req_idx * logits_stride + block,
-                mask=mask,
-                other=float("-inf"),
-            )
-
-            idx = tl.argmax(logits, axis=0)
-            value = tl.max(logits, axis=0)
-            is_greater = value > max_val
-            max_val = tl.where(is_greater, value, max_val)
-            max_idx = tl.where(is_greater, i + idx, max_idx)
-        tl.store(sampled_ptr + req_idx, max_idx)
-        return
-
-    # Random sampling.
-    # Calculate gumbel seed.
-    seed = tl.load(seeds_ptr + req_idx)
-    pos = tl.load(pos_ptr + req_idx)
-    gumbel_seed = tl.randint(seed, pos)
-
-    max_val = float("-inf")
-    max_idx = 0
-    for i in range(0, vocab_size, BLOCK_SIZE):
-        block = i + tl.arange(0, BLOCK_SIZE)
-        mask = block < vocab_size
+    if not is_greedy:
+        # Calculate the seed for gumbel noise.
+        seed = tl.load(seeds_ptr + req_idx)
+        pos = tl.load(pos_ptr + req_idx)
+        gumbel_seed = tl.randint(seed, pos)
 
         # Generate gumbel noise.
         r = tl.rand(gumbel_seed, block).to(tl.float64)
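Splitting the vocabulary across programs is safe for the random path because tl.rand is a counter-based (Philox) generator: the uniform drawn for a given (gumbel_seed, token offset) pair depends only on those two values, not on which program or block size produces it, so the per-token Gumbel noise, and hence the sampled token, is unchanged by the new 2D launch. A small standalone Triton sketch of that property (the kernel and helper names below are made up for illustration, and it needs a GPU):

import torch
import triton
import triton.language as tl


@triton.jit
def _fill_rand_kernel(out_ptr, seed, n, BLOCK_SIZE: tl.constexpr):
    # The value stored at each offset depends only on (seed, offset),
    # not on BLOCK_SIZE or the program id.
    block_idx = tl.program_id(0)
    offs = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n
    r = tl.rand(seed, offs)
    tl.store(out_ptr + offs, r, mask=mask)


def fill_rand(n: int, seed: int, block_size: int) -> torch.Tensor:
    out = torch.empty(n, dtype=torch.float32, device="cuda")
    grid = (triton.cdiv(n, block_size),)
    _fill_rand_kernel[grid](out, seed, n, BLOCK_SIZE=block_size)
    return out


# Same seed, different partitioning of the offsets -> identical noise.
a = fill_rand(50_000, seed=42, block_size=1024)
b = fill_rand(50_000, seed=42, block_size=4096)
assert torch.equal(a, b)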
@@ -129,16 +112,13 @@ def _gumbel_sample_kernel(
         gumbel_noise = gumbel_noise.to(tl.float32)
 
         # Apply gumbel noise.
-        logits = tl.load(logits_ptr + req_idx * logits_stride + block, mask=mask)
         logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
 
-        # Argmax to get the sampled token.
-        idx = tl.argmax(logits, axis=0)
-        value = tl.max(logits, axis=0)
-        is_greater = value > max_val
-        max_val = tl.where(is_greater, value, max_val)
-        max_idx = tl.where(is_greater, i + idx, max_idx)
-    tl.store(sampled_ptr + req_idx, max_idx)
+    idx = tl.argmax(logits, axis=0)
+    token_id = block_idx * BLOCK_SIZE + idx
+    value = tl.max(logits, axis=0)
+    tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id)
+    tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)
 
 
 def gumbel_sample(
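For reference, the reason adding -log(-log(u)) noise and taking an argmax samples from the softmax distribution is the Gumbel-max trick: if g_i are i.i.d. Gumbel(0, 1), then argmax_i(logits_i + g_i) is distributed as Categorical(softmax(logits)). A quick distribution-level check in plain PyTorch (sizes are illustrative), which also mirrors the kernel's choice of generating the noise in float64 before casting down:

import torch

torch.manual_seed(0)
logits = torch.randn(8)  # one row, tiny vocab for illustration
num_samples = 200_000

# Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1), in float64.
u = torch.rand(num_samples, 8, dtype=torch.float64)
gumbel = -torch.log(-torch.log(u))

# Argmax of (logits + noise) gives one categorical sample per row.
samples = (logits.double() + gumbel).argmax(dim=-1)

empirical = torch.bincount(samples, minlength=8).double() / num_samples
expected = torch.softmax(logits.double(), dim=-1)
print((empirical - expected).abs().max())  # small, on the order of 1e-3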
@@ -148,23 +128,36 @@ def gumbel_sample(
     pos: torch.Tensor,  # [num_reqs]
 ) -> torch.Tensor:
     num_reqs, vocab_size = logits.shape
-    # NOTE(woosuk): Use int64 for later indexing.
-    sampled = torch.empty(
+    BLOCK_SIZE = 1024
+    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
+    local_argmax = torch.empty(
         num_reqs,
+        num_blocks,
         dtype=torch.int64,
         device=logits.device,
     )
-    _gumbel_sample_kernel[(num_reqs,)](
-        sampled,
+    local_max = torch.empty(
+        num_reqs,
+        num_blocks,
+        dtype=torch.float32,
+        device=logits.device,
+    )
+    _gumbel_sample_kernel[(num_reqs, num_blocks)](
+        local_argmax,
+        local_argmax.stride(0),
+        local_max,
+        local_max.stride(0),
         logits,
         logits.stride(0),
         seed,
         pos,
         is_greedy,
         vocab_size,
         num_warps=8,
-        BLOCK_SIZE=16384,  # type: ignore
+        BLOCK_SIZE=BLOCK_SIZE,
     )
+    # NOTE(woosuk): Use int64 for later indexing.
+    max_block_idx = local_max.argmax(dim=-1, keepdim=True)
+    sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
     return sampled
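Putting the pieces together, a plain-PyTorch reference with the same control flow (greedy rows take a plain argmax; other rows add float64 Gumbel noise and take a row-wise argmax, which is what the per-block kernel plus the host-side gather compute) can serve as a distribution-level sanity check. It is a sketch only: PyTorch's RNG is not the counter-based scheme tl.rand uses, so it matches the kernel in distribution rather than bit-for-bit, and the names below are illustrative.

import torch


def gumbel_sample_reference(
    logits: torch.Tensor,     # [num_reqs, vocab_size]
    is_greedy: torch.Tensor,  # [num_reqs], bool
) -> torch.Tensor:
    # Gumbel(0, 1) noise in float64, then cast down, mirroring the kernel.
    u = torch.rand(logits.shape, dtype=torch.float64, device=logits.device)
    gumbel = (-torch.log(-torch.log(u))).to(logits.dtype)
    # Greedy rows get no noise; random rows argmax over the noisy logits.
    noisy = torch.where(is_greedy.unsqueeze(-1), logits, logits + gumbel)
    return noisy.argmax(dim=-1)


logits = torch.randn(4, 32_000)
is_greedy = torch.tensor([True, False, False, True])
print(gumbel_sample_reference(logits, is_greedy))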