[Model Runner V2] Add apply_temperature option to gumbel_sample (#29276)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 62d54ba46d
commit 3e1ad40655
@@ -68,9 +68,10 @@ class Sampler:
        sampled = gumbel_sample(
            logits,
            is_greedy,
            sampling_metadata.temperature,
            sampling_metadata.seeds,
            sampling_metadata.pos,
            apply_temperature=False,
        )
        return sampled, logits if return_logits else None
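For context, here is a rough PyTorch-level sketch of what this call computes, assuming the kernel behaves as shown in the hunks below. The name gumbel_sample_reference is hypothetical, and torch.rand stands in for the (seeds, pos)-derived randomness used by the real kernel; this is an illustration, not the actual implementation.

import torch

def gumbel_sample_reference(logits, is_greedy, temperature, seeds, pos,
                            apply_temperature: bool) -> torch.Tensor:
    # Hypothetical eager-mode stand-in for the Triton kernel.
    noisy = logits.clone()
    for i in range(logits.shape[0]):
        if is_greedy[i] or temperature[i] == 0.0:
            continue  # greedy requests fall through to a plain argmax below
        # The real kernel derives per-request randomness from (seeds[i], pos[i]);
        # torch.rand is used here purely for illustration.
        g = -torch.log(-torch.log(torch.rand_like(logits[i]) + 1e-20) + 1e-20)
        row = logits[i] / temperature[i] if apply_temperature else logits[i]
        noisy[i] = row + g
    return noisy.argmax(dim=-1)

Since the Sampler passes apply_temperature=False, the kernel adds Gumbel noise to the logits it receives without dividing by temperature; temperature is still passed so the kernel can skip the noise entirely (plain greedy argmax) when it is 0.0. Presumably any temperature scaling is applied to the logits upstream of this call.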
@@ -85,9 +86,10 @@ def _gumbel_sample_kernel(
    logits_stride,
    seeds_ptr,
    pos_ptr,
    is_greedy_ptr,
    temp_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    APPLY_TEMPERATURE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    block_idx = tl.program_id(1)
@@ -99,8 +101,8 @@ def _gumbel_sample_kernel(
        other=float("-inf"),
    )

    is_greedy = tl.load(is_greedy_ptr + req_idx)
    if not is_greedy:
        temp = tl.load(temp_ptr + req_idx)
        if temp != 0.0:
            # Calculate the seed for gumbel noise.
            seed = tl.load(seeds_ptr + req_idx)
            pos = tl.load(pos_ptr + req_idx)
@@ -111,6 +113,11 @@ def _gumbel_sample_kernel(
            gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
            gumbel_noise = gumbel_noise.to(tl.float32)

            # Apply temperature.
            if APPLY_TEMPERATURE:
                # NOTE(woosuk): Use div_rn to match the behavior of torch.
                logits = tl.div_rn(logits, temp.to(tl.float32))

            # Apply gumbel noise.
            logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
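The noise formula above is the standard Gumbel-max trick: for u ~ Uniform(0, 1), g = -log(-log(u)) is Gumbel-distributed, and argmax(logits + g) is distributed like a draw from softmax(logits). A small numerical check of that identity, purely illustrative and not part of the diff:

import torch

torch.manual_seed(0)
logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
n = 200_000
u = torch.rand(n, logits.numel())
gumbel = -torch.log(-torch.log(u + 1e-20) + 1e-20)  # same epsilons as the kernel
counts = torch.bincount((logits + gumbel).argmax(dim=-1), minlength=logits.numel())
print(counts / n)                     # empirical sampling frequencies
print(torch.softmax(logits, dim=-1))  # target probabilities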
@@ -123,9 +130,10 @@ def _gumbel_sample_kernel(

def gumbel_sample(
    logits: torch.Tensor,  # [num_reqs, vocab_size]
    is_greedy: torch.Tensor,  # [num_reqs]
    temperature: torch.Tensor,  # [num_reqs]
    seed: torch.Tensor,  # [num_reqs]
    pos: torch.Tensor,  # [num_reqs]
    apply_temperature: bool,
) -> torch.Tensor:
    num_reqs, vocab_size = logits.shape
    BLOCK_SIZE = 1024
@@ -151,9 +159,10 @@ def gumbel_sample(
        logits.stride(0),
        seed,
        pos,
        is_greedy,
        temperature,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
        APPLY_TEMPERATURE=apply_temperature,
    )
    # NOTE(woosuk): Use int64 for later indexing.
    max_block_idx = local_max.argmax(dim=-1, keepdim=True)
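A hedged usage sketch of the updated signature. Tensor shapes follow the comments in the signature above; the import path, dtypes, and the pre-scaling of logits in the second call are assumptions for illustration, not taken from this diff.

import torch
from vllm_gumbel import gumbel_sample  # placeholder path; the real module is not shown here

num_reqs, vocab_size = 4, 32_000
logits = torch.randn(num_reqs, vocab_size, device="cuda")
is_greedy = torch.tensor([True, False, False, False], device="cuda")
temperature = torch.tensor([0.0, 0.7, 1.0, 1.3], device="cuda")
seeds = torch.randint(0, 2**31 - 1, (num_reqs,), device="cuda")
pos = torch.zeros(num_reqs, dtype=torch.int64, device="cuda")

# Let the kernel divide the logits by temperature before adding Gumbel noise:
tokens = gumbel_sample(logits, is_greedy, temperature, seeds, pos,
                       apply_temperature=True)

# Or pass logits that were already temperature-scaled upstream and skip the
# division inside the kernel, which is how the Sampler above calls it:
scaled = logits / temperature.clamp_min(1e-6)[:, None]
tokens = gumbel_sample(scaled, is_greedy, temperature, seeds, pos,
                       apply_temperature=False)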