mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-10 23:43:33 +08:00
[Model Runner V2] Add apply_temperature option to gumbel_sample (#29276)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
62d54ba46d
commit
3e1ad40655
@ -68,9 +68,10 @@ class Sampler:
|
|||||||
|
|
||||||
sampled = gumbel_sample(
|
sampled = gumbel_sample(
|
||||||
logits,
|
logits,
|
||||||
is_greedy,
|
sampling_metadata.temperature,
|
||||||
sampling_metadata.seeds,
|
sampling_metadata.seeds,
|
||||||
sampling_metadata.pos,
|
sampling_metadata.pos,
|
||||||
|
apply_temperature=False,
|
||||||
)
|
)
|
||||||
return sampled, logits if return_logits else None
|
return sampled, logits if return_logits else None
|
||||||
|
|
||||||
@ -85,9 +86,10 @@ def _gumbel_sample_kernel(
|
|||||||
logits_stride,
|
logits_stride,
|
||||||
seeds_ptr,
|
seeds_ptr,
|
||||||
pos_ptr,
|
pos_ptr,
|
||||||
is_greedy_ptr,
|
temp_ptr,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
APPLY_TEMPERATURE: tl.constexpr,
|
||||||
):
|
):
|
||||||
req_idx = tl.program_id(0)
|
req_idx = tl.program_id(0)
|
||||||
block_idx = tl.program_id(1)
|
block_idx = tl.program_id(1)
|
||||||
@ -99,8 +101,8 @@ def _gumbel_sample_kernel(
|
|||||||
other=float("-inf"),
|
other=float("-inf"),
|
||||||
)
|
)
|
||||||
|
|
||||||
is_greedy = tl.load(is_greedy_ptr + req_idx)
|
temp = tl.load(temp_ptr + req_idx)
|
||||||
if not is_greedy:
|
if temp != 0.0:
|
||||||
# Calculate the seed for gumbel noise.
|
# Calculate the seed for gumbel noise.
|
||||||
seed = tl.load(seeds_ptr + req_idx)
|
seed = tl.load(seeds_ptr + req_idx)
|
||||||
pos = tl.load(pos_ptr + req_idx)
|
pos = tl.load(pos_ptr + req_idx)
|
||||||
@ -111,6 +113,11 @@ def _gumbel_sample_kernel(
|
|||||||
gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
|
gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
|
||||||
gumbel_noise = gumbel_noise.to(tl.float32)
|
gumbel_noise = gumbel_noise.to(tl.float32)
|
||||||
|
|
||||||
|
# Apply temperature.
|
||||||
|
if APPLY_TEMPERATURE:
|
||||||
|
# NOTE(woosuk): Use div_rn to match the behavior of torch.
|
||||||
|
logits = tl.div_rn(logits, temp.to(tl.float32))
|
||||||
|
|
||||||
# Apply gumbel noise.
|
# Apply gumbel noise.
|
||||||
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
|
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
|
||||||
|
|
||||||
@ -123,9 +130,10 @@ def _gumbel_sample_kernel(
|
|||||||
|
|
||||||
def gumbel_sample(
|
def gumbel_sample(
|
||||||
logits: torch.Tensor, # [num_reqs, vocab_size]
|
logits: torch.Tensor, # [num_reqs, vocab_size]
|
||||||
is_greedy: torch.Tensor, # [num_reqs]
|
temperature: torch.Tensor, # [num_reqs]
|
||||||
seed: torch.Tensor, # [num_reqs]
|
seed: torch.Tensor, # [num_reqs]
|
||||||
pos: torch.Tensor, # [num_reqs]
|
pos: torch.Tensor, # [num_reqs]
|
||||||
|
apply_temperature: bool,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
num_reqs, vocab_size = logits.shape
|
num_reqs, vocab_size = logits.shape
|
||||||
BLOCK_SIZE = 1024
|
BLOCK_SIZE = 1024
|
||||||
@ -151,9 +159,10 @@ def gumbel_sample(
|
|||||||
logits.stride(0),
|
logits.stride(0),
|
||||||
seed,
|
seed,
|
||||||
pos,
|
pos,
|
||||||
is_greedy,
|
temperature,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
BLOCK_SIZE=BLOCK_SIZE,
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
|
APPLY_TEMPERATURE=apply_temperature,
|
||||||
)
|
)
|
||||||
# NOTE(woosuk): Use int64 for later indexing.
|
# NOTE(woosuk): Use int64 for later indexing.
|
||||||
max_block_idx = local_max.argmax(dim=-1, keepdim=True)
|
max_block_idx = local_max.argmax(dim=-1, keepdim=True)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user