diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index fa2a6e590f22d..83ea766b1b4ad 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -128,8 +128,12 @@ class Sampler(nn.Module):
         self,
         logits: torch.Tensor,
         temp: torch.Tensor,
+        all_random: bool,
     ) -> torch.Tensor:
         # Use in-place division to avoid creating a new tensor.
+        # Avoid division by zero if there are greedy requests.
+        if not all_random:
+            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
         return logits.div_(temp.unsqueeze(dim=1))
 
     def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
@@ -164,7 +168,8 @@ class Sampler(nn.Module):
         assert sampling_metadata.temperature is not None
 
         # Apply temperature.
-        logits = self.apply_temperature(logits, sampling_metadata.temperature)
+        logits = self.apply_temperature(logits, sampling_metadata.temperature,
+                                        sampling_metadata.all_random)
 
         # Apply logits processors that only apply to random sampling
         # (argmax invariant)
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 79a392337574f..67fb9864b19c9 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -354,8 +354,8 @@ class InputBatch:
                 and is_spec_decode_unsupported(sampling_params)):
             self.spec_decode_unsupported_reqs.add(req_id)
         if sampling_params.sampling_type == SamplingType.GREEDY:
-            # Avoid later division by zero.
-            self.temperature_cpu[req_index] = -1.0
+            # Avoid division by zero later in apply_temperature.
+            self.temperature_cpu[req_index] = 0.0
             self.greedy_reqs.add(req_id)
         else:
             self.temperature_cpu[req_index] = sampling_params.temperature
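For context, a minimal standalone sketch (not part of the patch) of the patched apply_temperature behavior: masking near-zero temperatures with torch.where keeps the in-place division finite for greedy requests while leaving their logits unscaled, so a later argmax is unaffected. The _SAMPLING_EPS value below is an assumed illustrative constant; vLLM defines its own.

import torch

_SAMPLING_EPS = 1e-5  # assumed value for illustration; vLLM defines its own

def apply_temperature(logits: torch.Tensor, temp: torch.Tensor,
                      all_random: bool) -> torch.Tensor:
    # Greedy requests now carry temp == 0.0; replace those entries with
    # 1.0 (identity scaling) so the in-place division stays finite.
    if not all_random:
        temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
    return logits.div_(temp.unsqueeze(dim=1))

# One greedy request (temp 0.0) and one random request (temp 0.5).
logits = torch.tensor([[2.0, 1.0], [2.0, 1.0]])
temp = torch.tensor([0.0, 0.5])
out = apply_temperature(logits, temp, all_random=False)
# Greedy row passes through as [2.0, 1.0]; random row is scaled to [4.0, 2.0].
assert torch.argmax(out[0]).item() == 0

Note the fast path: when sampling_metadata.all_random is True, the batch contains no greedy requests, so the torch.where mask is skipped entirely and no extra kernel is launched.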