diff --git a/vllm/v1/sample/logits_processor.py b/vllm/v1/sample/logits_processor.py index 16bd2b9ffd841..3a4c25964e708 100644 --- a/vllm/v1/sample/logits_processor.py +++ b/vllm/v1/sample/logits_processor.py @@ -234,10 +234,16 @@ class MinPLogitsProcessor(LogitsProcessor): device="cpu", pin_memory=pin_memory) self.min_p_cpu = self.min_p_cpu_tensor.numpy() - # Pre-allocated device tensor - self.min_p_device: torch.Tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) + + self.use_double_tensor = torch.device("cpu") != torch.device(device) + + if self.use_double_tensor: + # Pre-allocated device tensor + self.min_p_device: torch.Tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + else: + self.min_p_device = self.min_p_cpu_tensor # Current slice of the device tensor self.min_p: torch.Tensor = self.min_p_device[:0] @@ -284,7 +290,9 @@ class MinPLogitsProcessor(LogitsProcessor): size = batch_update.batch_size if self.min_p_count and (needs_update or self.min_p.shape[0] != size): self.min_p = self.min_p_device[:size] - self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True) + if self.use_double_tensor: + self.min_p.copy_(self.min_p_cpu_tensor[:size], + non_blocking=True) self.min_p.unsqueeze_(1) def apply(self, logits: torch.Tensor) -> torch.Tensor: