Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-17 16:56:18 +08:00)
[V1][Sampler] Avoid an operation during temperature application (#13587)
This commit is contained in:
Parent: a30c093502
Commit: 31aa045c11
@@ -9,7 +9,7 @@ import torch
|
|||||||
@dataclass
|
@dataclass
|
||||||
class SamplingMetadata:
|
class SamplingMetadata:
|
||||||
|
|
||||||
temperature: torch.Tensor
|
temperature: Optional[torch.Tensor]
|
||||||
all_greedy: bool
|
all_greedy: bool
|
||||||
all_random: bool
|
all_random: bool
|
||||||
|
|
||||||
|
|||||||
@@ -77,11 +77,8 @@ class Sampler(nn.Module):
|
|||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
temp: torch.Tensor,
|
temp: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Avoid division by zero.
|
|
||||||
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
|
|
||||||
# Use in-place division to avoid creating a new tensor.
|
# Use in-place division to avoid creating a new tensor.
|
||||||
logits.div_(temp.unsqueeze(dim=1))
|
return logits.div_(temp.unsqueeze(dim=1))
|
||||||
return logits
|
|
||||||
|
|
||||||
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
|
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
|
||||||
return logits.argmax(dim=-1).view(-1)
|
return logits.argmax(dim=-1).view(-1)
|
||||||
@@ -100,6 +97,8 @@ class Sampler(nn.Module):
|
|||||||
if sampling_metadata.all_greedy:
|
if sampling_metadata.all_greedy:
|
||||||
return greedy_sampled
|
return greedy_sampled
|
||||||
|
|
||||||
|
assert sampling_metadata.temperature is not None
|
||||||
|
|
||||||
# Apply temperature.
|
# Apply temperature.
|
||||||
logits = self.apply_temperature(logits, sampling_metadata.temperature)
|
logits = self.apply_temperature(logits, sampling_metadata.temperature)
|
||||||
|
|
||||||
@@ -122,6 +121,7 @@ class Sampler(nn.Module):
|
|||||||
sampling_metadata.temperature < _SAMPLING_EPS,
|
sampling_metadata.temperature < _SAMPLING_EPS,
|
||||||
greedy_sampled,
|
greedy_sampled,
|
||||||
random_sampled,
|
random_sampled,
|
||||||
|
out=greedy_sampled, # Reuse tensor
|
||||||
)
|
)
|
||||||
return sampled
|
return sampled
|
||||||
|
|
||||||
|
|||||||
@@ -191,11 +191,13 @@ def bind_kv_cache(
|
|||||||
|
|
||||||
|
|
||||||
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
|
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
|
||||||
length: int) -> None:
|
length: int) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Copy the first length elements of a tensor into another tensor in a
|
Copy the first length elements of a tensor into another tensor in a
|
||||||
non-blocking manner.
|
non-blocking manner.
|
||||||
|
|
||||||
Used to copy pinned CPU tensor data to pre-allocated GPU tensors.
|
Used to copy pinned CPU tensor data to pre-allocated GPU tensors.
|
||||||
|
|
||||||
|
Returns the sliced target tensor.
|
||||||
"""
|
"""
|
||||||
to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
|
return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
|
||||||
|
|||||||
@@ -242,10 +242,12 @@ class InputBatch:
|
|||||||
self.block_table.add_row(req_index, request.block_ids)
|
self.block_table.add_row(req_index, request.block_ids)
|
||||||
|
|
||||||
sampling_params = request.sampling_params
|
sampling_params = request.sampling_params
|
||||||
self.temperature_cpu[req_index] = sampling_params.temperature
|
|
||||||
if sampling_params.sampling_type == SamplingType.GREEDY:
|
if sampling_params.sampling_type == SamplingType.GREEDY:
|
||||||
|
# Avoid later division by zero.
|
||||||
|
self.temperature_cpu[req_index] = -1.0
|
||||||
self.greedy_reqs.add(req_id)
|
self.greedy_reqs.add(req_id)
|
||||||
else:
|
else:
|
||||||
|
self.temperature_cpu[req_index] = sampling_params.temperature
|
||||||
self.random_reqs.add(req_id)
|
self.random_reqs.add(req_id)
|
||||||
|
|
||||||
self.top_p_cpu[req_index] = sampling_params.top_p
|
self.top_p_cpu[req_index] = sampling_params.top_p
|
||||||
@@ -410,7 +412,11 @@ class InputBatch:
|
|||||||
|
|
||||||
def _make_sampling_metadata(self) -> SamplingMetadata:
|
def _make_sampling_metadata(self) -> SamplingMetadata:
|
||||||
num_reqs = self.num_reqs
|
num_reqs = self.num_reqs
|
||||||
copy_slice(self.temperature_cpu_tensor, self.temperature, num_reqs)
|
if not self.all_greedy:
|
||||||
|
temperature = copy_slice(self.temperature_cpu_tensor,
|
||||||
|
self.temperature, num_reqs)
|
||||||
|
else:
|
||||||
|
temperature = None
|
||||||
if not self.no_top_p:
|
if not self.no_top_p:
|
||||||
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
|
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
|
||||||
if not self.no_top_k:
|
if not self.no_top_k:
|
||||||
@@ -437,7 +443,7 @@ class InputBatch:
|
|||||||
prompt_token_ids = None
|
prompt_token_ids = None
|
||||||
|
|
||||||
return SamplingMetadata(
|
return SamplingMetadata(
|
||||||
temperature=self.temperature[:num_reqs],
|
temperature=temperature,
|
||||||
all_greedy=self.all_greedy,
|
all_greedy=self.all_greedy,
|
||||||
all_random=self.all_random,
|
all_random=self.all_random,
|
||||||
top_p=None if self.no_top_p else self.top_p[:num_reqs],
|
top_p=None if self.no_top_p else self.top_p[:num_reqs],
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user