diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py
index 2184a1866ff5..6d82d3a79c8e 100644
--- a/vllm/v1/sample/metadata.py
+++ b/vllm/v1/sample/metadata.py
@@ -9,7 +9,7 @@ import torch
 
 @dataclass
 class SamplingMetadata:
 
-    temperature: torch.Tensor
+    temperature: Optional[torch.Tensor]
     all_greedy: bool
     all_random: bool
diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index 8e2533eefab0..ff978b3b6c41 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -77,11 +77,8 @@ class Sampler(nn.Module):
         logits: torch.Tensor,
         temp: torch.Tensor,
     ) -> torch.Tensor:
-        # Avoid division by zero.
-        temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
         # Use in-place division to avoid creating a new tensor.
-        logits.div_(temp.unsqueeze(dim=1))
-        return logits
+        return logits.div_(temp.unsqueeze(dim=1))
 
     def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
         return logits.argmax(dim=-1).view(-1)
@@ -100,6 +97,8 @@ class Sampler(nn.Module):
         if sampling_metadata.all_greedy:
             return greedy_sampled
 
+        assert sampling_metadata.temperature is not None
+
         # Apply temperature.
         logits = self.apply_temperature(logits, sampling_metadata.temperature)
 
@@ -122,6 +121,7 @@
             sampling_metadata.temperature < _SAMPLING_EPS,
             greedy_sampled,
             random_sampled,
+            out=greedy_sampled,  # Reuse tensor
         )
         return sampled
 
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
index 5be465014242..62271255b0c0 100644
--- a/vllm/v1/utils.py
+++ b/vllm/v1/utils.py
@@ -191,11 +191,13 @@ def bind_kv_cache(
 
 
 def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
-               length: int) -> None:
+               length: int) -> torch.Tensor:
     """
     Copy the first length elements of a tensor into another tensor in a
     non-blocking manner.
 
     Used to copy pinned CPU tensor data to pre-allocated GPU tensors.
+
+    Returns the sliced target tensor.
     """
-    to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
+    return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index ccafc325b53f..bd1c369acb30 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -242,10 +242,12 @@ class InputBatch:
         self.block_table.add_row(req_index, request.block_ids)
 
         sampling_params = request.sampling_params
-        self.temperature_cpu[req_index] = sampling_params.temperature
         if sampling_params.sampling_type == SamplingType.GREEDY:
+            # Avoid later division by zero.
+            self.temperature_cpu[req_index] = -1.0
             self.greedy_reqs.add(req_id)
         else:
+            self.temperature_cpu[req_index] = sampling_params.temperature
             self.random_reqs.add(req_id)
 
         self.top_p_cpu[req_index] = sampling_params.top_p
@@ -410,7 +412,11 @@
 
     def _make_sampling_metadata(self) -> SamplingMetadata:
         num_reqs = self.num_reqs
-        copy_slice(self.temperature_cpu_tensor, self.temperature, num_reqs)
+        if not self.all_greedy:
+            temperature = copy_slice(self.temperature_cpu_tensor,
+                                     self.temperature, num_reqs)
+        else:
+            temperature = None
         if not self.no_top_p:
             copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
         if not self.no_top_k:
@@ -437,7 +443,7 @@
             prompt_token_ids = None
 
         return SamplingMetadata(
-            temperature=self.temperature[:num_reqs],
+            temperature=temperature,
             all_greedy=self.all_greedy,
             all_random=self.all_random,
             top_p=None if self.no_top_p else self.top_p[:num_reqs],
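
Note on the greedy sentinel: writing -1.0 into temperature_cpu for greedy
requests keeps apply_temperature branch-free. Dividing by -1.0 merely negates
the greedy rows' logits instead of raising a division-by-zero, and those
scaled values are never consumed, because the final torch.where selects
greedy_sampled wherever temperature falls below _SAMPLING_EPS. A minimal
standalone sketch of that interplay (the _SAMPLING_EPS value and the tensor
contents here are illustrative, not taken from vLLM):

    import torch

    _SAMPLING_EPS = 1e-5  # illustrative value; vLLM defines its own constant

    # Two requests: one greedy (sentinel -1.0), one random (temperature 0.7).
    temperature = torch.tensor([-1.0, 0.7])
    logits = torch.tensor([[1.0, 2.0, 3.0],
                           [1.0, 2.0, 3.0]])

    # In-place, branch-free scaling: the greedy row is negated rather than
    # divided by zero, and its scaled logits are never sampled from.
    logits.div_(temperature.unsqueeze(dim=1))

    greedy_sampled = torch.tensor([2, 2])  # argmax of the raw logits
    random_sampled = torch.tensor([0, 1])  # e.g. drawn from the scaled rows

    # Rows whose temperature is below eps take the greedy result; out=
    # reuses greedy_sampled's storage instead of allocating a new tensor.
    sampled = torch.where(temperature < _SAMPLING_EPS,
                          greedy_sampled, random_sampled,
                          out=greedy_sampled)
    print(sampled)  # tensor([2, 1])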
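
Note on the copy_slice change: Tensor.copy_ returns its destination, so
returning to_tensor[:length].copy_(...) hands the caller the written
device-side view directly, which _make_sampling_metadata now uses as the
temperature slice (and skips entirely when all requests are greedy). A
CPU-only sketch of the pattern (buffer names are illustrative; the real
caller copies from pinned CPU memory into a pre-allocated GPU tensor):

    import torch

    def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
                   length: int) -> torch.Tensor:
        # copy_ returns its destination, i.e. the sliced view of to_tensor.
        return to_tensor[:length].copy_(from_tensor[:length],
                                        non_blocking=True)

    staging = torch.tensor([0.9, 0.5, 0.2, 0.0])  # stand-in for a pinned CPU buffer
    device_buf = torch.zeros(4)                   # stand-in for a GPU tensor

    view = copy_slice(staging, device_buf, 3)
    print(view)                                      # tensor([0.9000, 0.5000, 0.2000])
    print(view.data_ptr() == device_buf.data_ptr())  # True: a view, not a new tensor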