mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 09:37:25 +08:00
[V1][Sampler] Avoid an operation during temperature application (#13587)
This commit is contained in:
parent
a30c093502
commit
31aa045c11
@ -9,7 +9,7 @@ import torch
|
||||
@dataclass
|
||||
class SamplingMetadata:
|
||||
|
||||
temperature: torch.Tensor
|
||||
temperature: Optional[torch.Tensor]
|
||||
all_greedy: bool
|
||||
all_random: bool
|
||||
|
||||
|
||||
@ -77,11 +77,8 @@ class Sampler(nn.Module):
|
||||
logits: torch.Tensor,
|
||||
temp: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Avoid division by zero.
|
||||
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
|
||||
# Use in-place division to avoid creating a new tensor.
|
||||
logits.div_(temp.unsqueeze(dim=1))
|
||||
return logits
|
||||
return logits.div_(temp.unsqueeze(dim=1))
|
||||
|
||||
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
return logits.argmax(dim=-1).view(-1)
|
||||
@ -100,6 +97,8 @@ class Sampler(nn.Module):
|
||||
if sampling_metadata.all_greedy:
|
||||
return greedy_sampled
|
||||
|
||||
assert sampling_metadata.temperature is not None
|
||||
|
||||
# Apply temperature.
|
||||
logits = self.apply_temperature(logits, sampling_metadata.temperature)
|
||||
|
||||
@ -122,6 +121,7 @@ class Sampler(nn.Module):
|
||||
sampling_metadata.temperature < _SAMPLING_EPS,
|
||||
greedy_sampled,
|
||||
random_sampled,
|
||||
out=greedy_sampled, # Reuse tensor
|
||||
)
|
||||
return sampled
|
||||
|
||||
|
||||
@ -191,11 +191,13 @@ def bind_kv_cache(
|
||||
|
||||
|
||||
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
|
||||
length: int) -> None:
|
||||
length: int) -> torch.Tensor:
|
||||
"""
|
||||
Copy the first length elements of a tensor into another tensor in a
|
||||
non-blocking manner.
|
||||
|
||||
Used to copy pinned CPU tensor data to pre-allocated GPU tensors.
|
||||
|
||||
Returns the sliced target tensor.
|
||||
"""
|
||||
to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
|
||||
return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
|
||||
|
||||
@ -242,10 +242,12 @@ class InputBatch:
|
||||
self.block_table.add_row(req_index, request.block_ids)
|
||||
|
||||
sampling_params = request.sampling_params
|
||||
self.temperature_cpu[req_index] = sampling_params.temperature
|
||||
if sampling_params.sampling_type == SamplingType.GREEDY:
|
||||
# Avoid later division by zero.
|
||||
self.temperature_cpu[req_index] = -1.0
|
||||
self.greedy_reqs.add(req_id)
|
||||
else:
|
||||
self.temperature_cpu[req_index] = sampling_params.temperature
|
||||
self.random_reqs.add(req_id)
|
||||
|
||||
self.top_p_cpu[req_index] = sampling_params.top_p
|
||||
@ -410,7 +412,11 @@ class InputBatch:
|
||||
|
||||
def _make_sampling_metadata(self) -> SamplingMetadata:
|
||||
num_reqs = self.num_reqs
|
||||
copy_slice(self.temperature_cpu_tensor, self.temperature, num_reqs)
|
||||
if not self.all_greedy:
|
||||
temperature = copy_slice(self.temperature_cpu_tensor,
|
||||
self.temperature, num_reqs)
|
||||
else:
|
||||
temperature = None
|
||||
if not self.no_top_p:
|
||||
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
|
||||
if not self.no_top_k:
|
||||
@ -437,7 +443,7 @@ class InputBatch:
|
||||
prompt_token_ids = None
|
||||
|
||||
return SamplingMetadata(
|
||||
temperature=self.temperature[:num_reqs],
|
||||
temperature=temperature,
|
||||
all_greedy=self.all_greedy,
|
||||
all_random=self.all_random,
|
||||
top_p=None if self.no_top_p else self.top_p[:num_reqs],
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user