[V1][Sampler] Avoid an operation during temperature application (#13587)

This commit is contained in:
Nick Hill 2025-02-20 22:05:56 -08:00 committed by GitHub
parent a30c093502
commit 31aa045c11
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 18 additions and 10 deletions

View File

@@ -9,7 +9,7 @@ import torch
@dataclass
class SamplingMetadata:
temperature: torch.Tensor
temperature: Optional[torch.Tensor]
all_greedy: bool
all_random: bool

View File

@@ -77,11 +77,8 @@ class Sampler(nn.Module):
logits: torch.Tensor,
temp: torch.Tensor,
) -> torch.Tensor:
# Avoid division by zero.
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
# Use in-place division to avoid creating a new tensor.
logits.div_(temp.unsqueeze(dim=1))
return logits
return logits.div_(temp.unsqueeze(dim=1))
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
return logits.argmax(dim=-1).view(-1)
@@ -100,6 +97,8 @@ class Sampler(nn.Module):
if sampling_metadata.all_greedy:
return greedy_sampled
assert sampling_metadata.temperature is not None
# Apply temperature.
logits = self.apply_temperature(logits, sampling_metadata.temperature)
@@ -122,6 +121,7 @@ class Sampler(nn.Module):
sampling_metadata.temperature < _SAMPLING_EPS,
greedy_sampled,
random_sampled,
out=greedy_sampled, # Reuse tensor
)
return sampled

View File

@@ -191,11 +191,13 @@ def bind_kv_cache(
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
length: int) -> None:
length: int) -> torch.Tensor:
"""
Copy the first length elements of a tensor into another tensor in a
non-blocking manner.
Used to copy pinned CPU tensor data to pre-allocated GPU tensors.
Returns the sliced target tensor.
"""
to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)

View File

@@ -242,10 +242,12 @@ class InputBatch:
self.block_table.add_row(req_index, request.block_ids)
sampling_params = request.sampling_params
self.temperature_cpu[req_index] = sampling_params.temperature
if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero.
self.temperature_cpu[req_index] = -1.0
self.greedy_reqs.add(req_id)
else:
self.temperature_cpu[req_index] = sampling_params.temperature
self.random_reqs.add(req_id)
self.top_p_cpu[req_index] = sampling_params.top_p
@@ -410,7 +412,11 @@ class InputBatch:
def _make_sampling_metadata(self) -> SamplingMetadata:
num_reqs = self.num_reqs
copy_slice(self.temperature_cpu_tensor, self.temperature, num_reqs)
if not self.all_greedy:
temperature = copy_slice(self.temperature_cpu_tensor,
self.temperature, num_reqs)
else:
temperature = None
if not self.no_top_p:
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
if not self.no_top_k:
@@ -437,7 +443,7 @@ class InputBatch:
prompt_token_ids = None
return SamplingMetadata(
temperature=self.temperature[:num_reqs],
temperature=temperature,
all_greedy=self.all_greedy,
all_random=self.all_random,
top_p=None if self.no_top_p else self.top_p[:num_reqs],