[perf] fix current stream (#11870)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao 2025-01-09 15:18:21 +08:00 committed by GitHub
parent a732900efc
commit 310aca88c9
4 changed files with 46 additions and 15 deletions

vllm/distributed/device_communicators/pynccl.py

@@ -10,6 +10,7 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
ncclRedOpTypeEnum, ncclUniqueId)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
+ from vllm.utils import current_stream
logger = init_logger(__name__)
@@ -96,7 +97,7 @@ class PyNcclCommunicator:
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
self.world_size, self.unique_id, self.rank)
- stream = torch.cuda.current_stream()
+ stream = current_stream()
# A small all_reduce for warmup.
data = torch.zeros(1, device=device)
self.all_reduce(data)
@@ -119,7 +120,7 @@ class PyNcclCommunicator:
out_tensor = torch.empty_like(in_tensor)
if stream is None:
- stream = torch.cuda.current_stream()
+ stream = current_stream()
self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
buffer_type(out_tensor.data_ptr()),
in_tensor.numel(),
@@ -141,7 +142,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
- stream = torch.cuda.current_stream()
+ stream = current_stream()
self.nccl.ncclAllGather(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
@@ -162,7 +163,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
- stream = torch.cuda.current_stream()
+ stream = current_stream()
self.nccl.ncclReduceScatter(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
@@ -177,7 +178,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
- stream = torch.cuda.current_stream()
+ stream = current_stream()
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), dst,
self.comm, cudaStream_t(stream.cuda_stream))
@@ -189,7 +190,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
- stream = torch.cuda.current_stream()
+ stream = current_stream()
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))
@@ -201,7 +202,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
- stream = torch.cuda.current_stream()
+ stream = current_stream()
if src == self.rank:
sendbuff = buffer_type(tensor.data_ptr())
# NCCL requires the sender also to have a receive buffer

vllm/distributed/parallel_state.py

@@ -357,10 +357,7 @@ class GroupCoordinator:
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
- # TODO: pynccl should not use `stream=`
- # it can just always use the current stream.
- out = pynccl_comm.all_reduce(input_,
-                              stream=torch.cuda.current_stream())
+ out = pynccl_comm.all_reduce(input_)
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
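With the explicit stream argument gone, callers just invoke the collective and the communicator resolves the stream internally via vllm.utils.current_stream(). A caller-side sketch (the communicator setup and tensor are assumed here, not shown):

    import torch

    # `pynccl_comm` is assumed to be an initialized PyNcclCommunicator and `x`
    # a tensor on this rank's device; distributed setup is omitted for brevity.
    def allreduce(pynccl_comm, x: torch.Tensor) -> torch.Tensor:
        # Equivalent to the old call with stream=torch.cuda.current_stream(),
        # but without constructing a new Stream object on every invocation.
        return pynccl_comm.all_reduce(x)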

vllm/utils.py

@@ -944,6 +944,39 @@ def find_nccl_library() -> str:
return so_file

+prev_set_stream = torch.cuda.set_stream
+
+_current_stream = None
+
+
+def _patched_set_stream(stream: torch.cuda.Stream) -> None:
+    global _current_stream
+    _current_stream = stream
+    prev_set_stream(stream)
+
+
+torch.cuda.set_stream = _patched_set_stream
+
+
+def current_stream() -> torch.cuda.Stream:
+    """
+    replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`.
+    it turns out that `torch.cuda.current_stream()` is quite expensive,
+    as it will construct a new stream object at each call.
+    here we patch `torch.cuda.set_stream` to keep track of the current stream
+    directly, so that we can avoid calling `torch.cuda.current_stream()`.
+
+    the underlying hypothesis is that we do not call `torch._C._cuda_setStream`
+    from C/C++ code.
+    """
+    global _current_stream
+    if _current_stream is None:
+        # when this function is called before any stream is set,
+        # we return the default stream.
+        _current_stream = torch.cuda.current_stream()
+    return _current_stream
def enable_trace_function_call_for_thread(vllm_config: "VllmConfig") -> None:
"""Set up function tracing for the current thread,
if enabled via the VLLM_TRACE_FUNCTION environment variable
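As a quick illustration of the new helper (a usage sketch, assuming a CUDA device is available; the timing loop is illustrative only, not a measurement from this commit):

    import time

    import torch

    from vllm.utils import current_stream

    # torch.cuda.stream() enters and exits via torch.cuda.set_stream, so the
    # patched setter records the active stream and current_stream() can return
    # it without constructing a new Stream object on every call.
    s = torch.cuda.Stream()
    with torch.cuda.stream(s):
        assert current_stream() == torch.cuda.current_stream()

    # Rough cost comparison of the cached helper vs. the per-call query.
    N = 10_000
    t0 = time.perf_counter()
    for _ in range(N):
        torch.cuda.current_stream()
    t1 = time.perf_counter()
    for _ in range(N):
        current_stream()
    t2 = time.perf_counter()
    print(f"torch.cuda.current_stream: {t1 - t0:.4f}s, current_stream: {t2 - t1:.4f}s")

The hypothesis stated in the docstring matters here: the cache stays correct only as long as stream switches go through the Python-level torch.cuda.set_stream.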

vllm/worker/multi_step_model_runner.py

@@ -14,7 +14,7 @@ from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
get_pythonized_sample_results)
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SequenceGroupMetadata, SequenceOutput)
- from vllm.utils import PyObjectCache, async_tensor_h2d
+ from vllm.utils import PyObjectCache, async_tensor_h2d, current_stream
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
@@ -498,7 +498,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# appended sampler output from last iteration
# - also maybe pythonize if CPU is ahead of GPU
- current_stream = torch.cuda.current_stream()
+ stream = current_stream()
if not model_input.is_first_multi_step:
# Explicitly block on the previous step's forward to make sure we
# don't clobber any GPU tensors still in use.
@@ -541,7 +541,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
num_steps=1)
# record the event for the current step so that the next step can sync
- model_input.record_step_event(current_stream)
+ model_input.record_step_event(stream)
if get_pp_group().is_last_rank and self.is_driver_worker:
assert isinstance(output, list)
@@ -552,7 +552,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# event for the pythonization so that we only pythonize if the
# tensors are ready. May be able to be combined with the step event
output_ready_event = torch.cuda.Event()
- output_ready_event.record(current_stream)
+ output_ready_event.record(stream)
if self.parallel_config.pipeline_parallel_size > 1:
output[0].sampled_token_ids_cpu = output[
0].sampled_token_ids.cpu()
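The same cached stream now feeds the step and output-ready events in the multi-step runner. A condensed sketch of that record/wait pattern (the helper names here are illustrative, not the runner's API):

    import torch

    from vllm.utils import current_stream

    def record_step_event() -> torch.cuda.Event:
        # Mark the point on the current step's stream that the next step
        # (or the pythonization path) must not run ahead of.
        event = torch.cuda.Event()
        event.record(current_stream())
        return event

    def wait_for_step(event: torch.cuda.Event) -> None:
        # Block new work on the cached current stream until the recorded
        # point is reached, so in-flight GPU tensors are not clobbered.
        event.wait(current_stream())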