mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-12 19:09:46 +08:00
[Model Runner V2] Optimize Gumbel Sampling Kernel (#29210)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
c6fa3895e9
commit
e9af6ba62a
@ -3,10 +3,9 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
from vllm.config.model import LogprobsMode
|
from vllm.config.model import LogprobsMode
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||||
from vllm.v1.worker.gpu.states import SamplingMetadata
|
from vllm.v1.worker.gpu.states import SamplingMetadata
|
||||||
@ -78,7 +77,10 @@ class Sampler:
|
|||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _gumbel_sample_kernel(
|
def _gumbel_sample_kernel(
|
||||||
sampled_ptr,
|
local_argmax_ptr,
|
||||||
|
local_argmax_stride,
|
||||||
|
local_max_ptr,
|
||||||
|
local_max_stride,
|
||||||
logits_ptr,
|
logits_ptr,
|
||||||
logits_stride,
|
logits_stride,
|
||||||
seeds_ptr,
|
seeds_ptr,
|
||||||
@ -88,14 +90,8 @@ def _gumbel_sample_kernel(
|
|||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
):
|
):
|
||||||
req_idx = tl.program_id(0)
|
req_idx = tl.program_id(0)
|
||||||
is_greedy = tl.load(is_greedy_ptr + req_idx)
|
block_idx = tl.program_id(1)
|
||||||
|
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||||
if is_greedy:
|
|
||||||
# Greedy sampling. Don't apply gumbel noise.
|
|
||||||
max_val = float("-inf")
|
|
||||||
max_idx = 0
|
|
||||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
|
||||||
block = i + tl.arange(0, BLOCK_SIZE)
|
|
||||||
mask = block < vocab_size
|
mask = block < vocab_size
|
||||||
logits = tl.load(
|
logits = tl.load(
|
||||||
logits_ptr + req_idx * logits_stride + block,
|
logits_ptr + req_idx * logits_stride + block,
|
||||||
@ -103,42 +99,26 @@ def _gumbel_sample_kernel(
|
|||||||
other=float("-inf"),
|
other=float("-inf"),
|
||||||
)
|
)
|
||||||
|
|
||||||
idx = tl.argmax(logits, axis=0)
|
is_greedy = tl.load(is_greedy_ptr + req_idx)
|
||||||
value = tl.max(logits, axis=0)
|
if not is_greedy:
|
||||||
is_greater = value > max_val
|
# Calculate the seed for gumbel noise.
|
||||||
max_val = tl.where(is_greater, value, max_val)
|
|
||||||
max_idx = tl.where(is_greater, i + idx, max_idx)
|
|
||||||
tl.store(sampled_ptr + req_idx, max_idx)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Random sampling.
|
|
||||||
# Calculate gumbel seed.
|
|
||||||
seed = tl.load(seeds_ptr + req_idx)
|
seed = tl.load(seeds_ptr + req_idx)
|
||||||
pos = tl.load(pos_ptr + req_idx)
|
pos = tl.load(pos_ptr + req_idx)
|
||||||
gumbel_seed = tl.randint(seed, pos)
|
gumbel_seed = tl.randint(seed, pos)
|
||||||
|
|
||||||
max_val = float("-inf")
|
|
||||||
max_idx = 0
|
|
||||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
|
||||||
block = i + tl.arange(0, BLOCK_SIZE)
|
|
||||||
mask = block < vocab_size
|
|
||||||
|
|
||||||
# Generate gumbel noise.
|
# Generate gumbel noise.
|
||||||
r = tl.rand(gumbel_seed, block).to(tl.float64)
|
r = tl.rand(gumbel_seed, block).to(tl.float64)
|
||||||
gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
|
gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
|
||||||
gumbel_noise = gumbel_noise.to(tl.float32)
|
gumbel_noise = gumbel_noise.to(tl.float32)
|
||||||
|
|
||||||
# Apply gumbel noise.
|
# Apply gumbel noise.
|
||||||
logits = tl.load(logits_ptr + req_idx * logits_stride + block, mask=mask)
|
|
||||||
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
|
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
|
||||||
|
|
||||||
# Argmax to get the sampled token.
|
|
||||||
idx = tl.argmax(logits, axis=0)
|
idx = tl.argmax(logits, axis=0)
|
||||||
|
token_id = block_idx * BLOCK_SIZE + idx
|
||||||
value = tl.max(logits, axis=0)
|
value = tl.max(logits, axis=0)
|
||||||
is_greater = value > max_val
|
tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id)
|
||||||
max_val = tl.where(is_greater, value, max_val)
|
tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)
|
||||||
max_idx = tl.where(is_greater, i + idx, max_idx)
|
|
||||||
tl.store(sampled_ptr + req_idx, max_idx)
|
|
||||||
|
|
||||||
|
|
||||||
def gumbel_sample(
|
def gumbel_sample(
|
||||||
@ -148,23 +128,36 @@ def gumbel_sample(
|
|||||||
pos: torch.Tensor, # [num_reqs]
|
pos: torch.Tensor, # [num_reqs]
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
num_reqs, vocab_size = logits.shape
|
num_reqs, vocab_size = logits.shape
|
||||||
# NOTE(woosuk): Use int64 for later indexing.
|
BLOCK_SIZE = 1024
|
||||||
sampled = torch.empty(
|
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
||||||
|
local_argmax = torch.empty(
|
||||||
num_reqs,
|
num_reqs,
|
||||||
|
num_blocks,
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
device=logits.device,
|
device=logits.device,
|
||||||
)
|
)
|
||||||
_gumbel_sample_kernel[(num_reqs,)](
|
local_max = torch.empty(
|
||||||
sampled,
|
num_reqs,
|
||||||
|
num_blocks,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=logits.device,
|
||||||
|
)
|
||||||
|
_gumbel_sample_kernel[(num_reqs, num_blocks)](
|
||||||
|
local_argmax,
|
||||||
|
local_argmax.stride(0),
|
||||||
|
local_max,
|
||||||
|
local_max.stride(0),
|
||||||
logits,
|
logits,
|
||||||
logits.stride(0),
|
logits.stride(0),
|
||||||
seed,
|
seed,
|
||||||
pos,
|
pos,
|
||||||
is_greedy,
|
is_greedy,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
num_warps=8,
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
BLOCK_SIZE=16384, # type: ignore
|
|
||||||
)
|
)
|
||||||
|
# NOTE(woosuk): Use int64 for later indexing.
|
||||||
|
max_block_idx = local_max.argmax(dim=-1, keepdim=True)
|
||||||
|
sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
|
||||||
return sampled
|
return sampled
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user