[Model Runner V2] Add apply_temperature option to gumbel_sample (#29276)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-11-23 14:13:00 -08:00 committed by GitHub
parent 62d54ba46d
commit 3e1ad40655
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -68,9 +68,10 @@ class Sampler:
sampled = gumbel_sample(
logits,
is_greedy,
sampling_metadata.temperature,
sampling_metadata.seeds,
sampling_metadata.pos,
apply_temperature=False,
)
return sampled, logits if return_logits else None
@ -85,9 +86,10 @@ def _gumbel_sample_kernel(
logits_stride,
seeds_ptr,
pos_ptr,
is_greedy_ptr,
temp_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
APPLY_TEMPERATURE: tl.constexpr,
):
req_idx = tl.program_id(0)
block_idx = tl.program_id(1)
@ -99,8 +101,8 @@ def _gumbel_sample_kernel(
other=float("-inf"),
)
is_greedy = tl.load(is_greedy_ptr + req_idx)
if not is_greedy:
temp = tl.load(temp_ptr + req_idx)
if temp != 0.0:
# Calculate the seed for gumbel noise.
seed = tl.load(seeds_ptr + req_idx)
pos = tl.load(pos_ptr + req_idx)
@ -111,6 +113,11 @@ def _gumbel_sample_kernel(
gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
gumbel_noise = gumbel_noise.to(tl.float32)
# Apply temperature.
if APPLY_TEMPERATURE:
# NOTE(woosuk): Use div_rn to match the behavior of torch.
logits = tl.div_rn(logits, temp.to(tl.float32))
# Apply gumbel noise.
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
@ -123,9 +130,10 @@ def _gumbel_sample_kernel(
def gumbel_sample(
logits: torch.Tensor, # [num_reqs, vocab_size]
is_greedy: torch.Tensor, # [num_reqs]
temperature: torch.Tensor, # [num_reqs]
seed: torch.Tensor, # [num_reqs]
pos: torch.Tensor, # [num_reqs]
apply_temperature: bool,
) -> torch.Tensor:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 1024
@ -151,9 +159,10 @@ def gumbel_sample(
logits.stride(0),
seed,
pos,
is_greedy,
temperature,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
APPLY_TEMPERATURE=apply_temperature,
)
# NOTE(woosuk): Use int64 for later indexing.
max_block_idx = local_max.argmax(dim=-1, keepdim=True)