From bd29cf3d3ad3dd06105f1a4bb9023bb23bdfd5ed Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Wed, 20 Dec 2023 00:04:33 -0800
Subject: [PATCH] Remove Sampler copy stream (#2209)

---
 vllm/model_executor/layers/sampler.py | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index fe88b0ea4293..f9d95fa7548f 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -30,7 +30,6 @@ class Sampler(nn.Module):
     def __init__(self, vocab_size: int) -> None:
         super().__init__()
         self.vocab_size = vocab_size
-        self._copy_stream: torch.cuda.Stream = torch.cuda.Stream()
 
     def forward(
         self,
@@ -51,14 +50,10 @@ class Sampler(nn.Module):
         # Apply logits processors (if any).
         logits = _apply_logits_processors(logits, sampling_metadata)
 
-        # Prepare sampling tensors in another stream to overlap
-        # CPU<->GPU data transfer with GPU computation in forward pass.
-        with torch.cuda.stream(self._copy_stream):
-            (sampling_tensors, do_penalties, do_top_p_top_k,
-             do_min_p) = SamplingTensors.from_sampling_metadata(
-                 sampling_metadata, vocab_size, logits.device, logits.dtype)
-
-        torch.cuda.current_stream().wait_stream(self._copy_stream)
+        # Prepare sampling tensors with pinned memory to avoid blocking.
+        (sampling_tensors, do_penalties, do_top_p_top_k,
+         do_min_p) = SamplingTensors.from_sampling_metadata(
+             sampling_metadata, vocab_size, logits.device, logits.dtype)
 
         # Apply presence and frequency penalties.
         if do_penalties:
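
The patch above drops the dedicated CUDA copy stream that previously overlapped CPU<->GPU data transfer with the forward pass, relying instead on pinned host memory so the copy does not block the host. Below is a minimal sketch of that pinned-memory, non-blocking transfer pattern. Note the names here (to_device_pinned, the temperature example) are illustrative, not vLLM APIs; SamplingTensors.from_sampling_metadata is assumed to do something equivalent internally.

import torch

def to_device_pinned(values, device, dtype):
    # Staging the CPU tensor in page-locked (pinned) memory lets
    # non_blocking=True issue a truly asynchronous host-to-device copy:
    # the host thread returns immediately instead of waiting for the
    # transfer to finish, so no separate copy stream is needed.
    cpu_tensor = torch.tensor(values, dtype=dtype, pin_memory=True)
    return cpu_tensor.to(device=device, non_blocking=True)

# Hypothetical usage: move per-request sampling temperatures to the GPU
# without blocking the host.
if torch.cuda.is_available():
    temperatures = to_device_pinned([0.7, 1.0, 0.9],
                                    torch.device("cuda"),
                                    torch.float32)

One trade-off worth noting: a non-blocking copy from pinned memory still proceeds asynchronously on the current stream, so later kernels on that stream are correctly ordered after it, without the explicit wait_stream() synchronization the removed code needed.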