[perf]fix current stream (#11870)
Signed-off-by: youkaichao <youkaichao@gmail.com>
commit 310aca88c9
parent a732900efc
vllm/distributed/device_communicators/pynccl.py

@@ -10,6 +10,7 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
                                              ncclRedOpTypeEnum, ncclUniqueId)
 from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
+from vllm.utils import current_stream

 logger = init_logger(__name__)

@@ -96,7 +97,7 @@ class PyNcclCommunicator:
         self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
             self.world_size, self.unique_id, self.rank)

-        stream = torch.cuda.current_stream()
+        stream = current_stream()
         # A small all_reduce for warmup.
         data = torch.zeros(1, device=device)
         self.all_reduce(data)
@@ -119,7 +120,7 @@ class PyNcclCommunicator:
         out_tensor = torch.empty_like(in_tensor)

         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                 buffer_type(out_tensor.data_ptr()),
                                 in_tensor.numel(),
@@ -141,7 +142,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         self.nccl.ncclAllGather(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
@@ -162,7 +163,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         self.nccl.ncclReduceScatter(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
@@ -177,7 +178,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -189,7 +190,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), src,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -201,7 +202,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         if src == self.rank:
             sendbuff = buffer_type(tensor.data_ptr())
             # NCCL requires the sender also to have a receive buffer
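Every collective above now resolves a missing stream argument through the cached accessor instead of asking PyTorch for a fresh Stream object on each call. A minimal sketch of that fallback pattern, assuming vllm is installed; the helper name _resolve_stream is hypothetical and not part of this diff:

    from typing import Optional

    import torch

    from vllm.utils import current_stream


    def _resolve_stream(stream: Optional[torch.cuda.Stream]) -> torch.cuda.Stream:
        # Fall back to the cached current stream when the caller does not pass
        # one explicitly. current_stream() returns the same Stream object on
        # repeated calls rather than constructing a new one each time.
        return stream if stream is not None else current_stream()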
vllm/distributed/parallel_state.py

@@ -357,10 +357,7 @@ class GroupCoordinator:
             return out
         pynccl_comm = self.pynccl_comm
         assert pynccl_comm is not None
-        # TODO: pynccl should not use `stream=`
-        # it can just always use the current stream.
-        out = pynccl_comm.all_reduce(input_,
-                                     stream=torch.cuda.current_stream())
+        out = pynccl_comm.all_reduce(input_)
         if out is None:
             # fall back to the default all-reduce using PyTorch.
             # this usually happens during testing.
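Since the communicator now defaults to the cached current stream, callers no longer thread a stream through the call. A rough before/after sketch of the caller side, assuming an initialized PyNcclCommunicator pynccl_comm and a CUDA tensor input_:

    # before: the caller fetched the stream and passed it explicitly
    out = pynccl_comm.all_reduce(input_, stream=torch.cuda.current_stream())

    # after: the communicator resolves the current stream internally
    out = pynccl_comm.all_reduce(input_)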
vllm/utils.py

@@ -944,6 +944,39 @@ def find_nccl_library() -> str:
     return so_file


+prev_set_stream = torch.cuda.set_stream
+
+_current_stream = None
+
+
+def _patched_set_stream(stream: torch.cuda.Stream) -> None:
+    global _current_stream
+    _current_stream = stream
+    prev_set_stream(stream)
+
+
+torch.cuda.set_stream = _patched_set_stream
+
+
+def current_stream() -> torch.cuda.Stream:
+    """
+    replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`.
+    it turns out that `torch.cuda.current_stream()` is quite expensive,
+    as it will construct a new stream object at each call.
+    here we patch `torch.cuda.set_stream` to keep track of the current stream
+    directly, so that we can avoid calling `torch.cuda.current_stream()`.
+
+    the underlying hypothesis is that we do not call `torch._C._cuda_setStream`
+    from C/C++ code.
+    """
+    global _current_stream
+    if _current_stream is None:
+        # when this function is called before any stream is set,
+        # we return the default stream.
+        _current_stream = torch.cuda.current_stream()
+    return _current_stream
+
+
 def enable_trace_function_call_for_thread(vllm_config: "VllmConfig") -> None:
     """Set up function tracing for the current thread,
     if enabled via the VLLM_TRACE_FUNCTION environment variable
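The docstring above makes a performance claim that is easy to check. A rough micro-benchmark sketch, not part of the commit, assuming a CUDA-capable machine with vllm installed; absolute numbers are illustrative only:

    import time

    import torch

    from vllm.utils import current_stream


    def bench(fn, n: int = 100_000) -> float:
        # Time n back-to-back calls; both functions are pure CPU-side work,
        # so no GPU synchronization is needed.
        start = time.perf_counter()
        for _ in range(n):
            fn()
        return time.perf_counter() - start


    torch.cuda.init()
    print(f"torch.cuda.current_stream(): {bench(torch.cuda.current_stream):.3f}s")
    print(f"vllm.utils.current_stream(): {bench(current_stream):.3f}s")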
vllm/worker/multi_step_model_runner.py

@@ -14,7 +14,7 @@ from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
                                                 get_pythonized_sample_results)
 from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
                            Logprob, SequenceGroupMetadata, SequenceOutput)
-from vllm.utils import PyObjectCache, async_tensor_h2d
+from vllm.utils import PyObjectCache, async_tensor_h2d, current_stream
 from vllm.worker.model_runner import (GPUModelRunnerBase,
                                       ModelInputForGPUWithSamplingMetadata)
 from vllm.worker.model_runner_base import (
@@ -498,7 +498,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
         # appended sampler output from last iteration
         # - also maybe pythonize if CPU is ahead of GPU

-        current_stream = torch.cuda.current_stream()
+        stream = current_stream()
         if not model_input.is_first_multi_step:
             # Explicitly block on the previous step's forward to make sure we
             # don't clobber any GPU tensors still in use.
@@ -541,7 +541,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
             num_steps=1)

         # record the event for the current step so that the next step can sync
-        model_input.record_step_event(current_stream)
+        model_input.record_step_event(stream)

         if get_pp_group().is_last_rank and self.is_driver_worker:
             assert isinstance(output, list)
@@ -552,7 +552,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
             # event for the pythonization so that we only pythonize if the
             # tensors are ready. May be able to be combined with the step event
             output_ready_event = torch.cuda.Event()
-            output_ready_event.record(current_stream)
+            output_ready_event.record(stream)
             if self.parallel_config.pipeline_parallel_size > 1:
                 output[0].sampled_token_ids_cpu = output[
                     0].sampled_token_ids.cpu()
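For context, the runner records CUDA events on this stream so that a later step (and the pythonization path) can wait for the GPU work to finish. A self-contained sketch of that event pattern, not vllm code, assuming a CUDA device is available:

    import torch

    from vllm.utils import current_stream

    stream = current_stream()

    # enqueue some GPU work on the current stream
    x = torch.randn(1024, device="cuda")
    y = x * 2

    # record an event after the work, then wait on it before reading the
    # result; in the runner this gate sits between multi-step iterations.
    output_ready_event = torch.cuda.Event()
    output_ready_event.record(stream)
    output_ready_event.synchronize()
    print(y.sum().item())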