diff --git a/requirements.txt b/requirements.txt
index f509bdf4907b..e9023bde82a4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,4 +12,3 @@ fastapi
 uvicorn[standard]
 pydantic == 1.10.13  # Required for OpenAI server.
 aioprometheus[starlette]
-cupy-cuda12x  # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.  # FIXME: Fix this in setup.py.
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 74070e238bb2..d91ab1430735 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -17,7 +17,7 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
                            SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
-from vllm.utils import Counter, get_open_port
+from vllm.utils import Counter
 
 if ray:
     from ray.air.util.torch_dist import init_torch_dist_process_group
@@ -190,7 +190,6 @@ class LLMEngine:
             ))
         self._run_workers(
             "init_model",
-            cupy_port=get_open_port(),
             get_all_outputs=True,
         )
         self._run_workers(
diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py
index e2dfa7ff601c..b1d5f5b9fb88 100644
--- a/vllm/model_executor/parallel_utils/communication_op.py
+++ b/vllm/model_executor/parallel_utils/communication_op.py
@@ -1,10 +1,8 @@
 import torch
 
-from vllm.model_executor.parallel_utils import cupy_utils
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size,
     get_tensor_model_parallel_group,
-    is_custom_nccl_enabled_for_all_reduce,
 )
 
 
@@ -17,12 +15,8 @@ def tensor_model_parallel_all_reduce(input_):
     if get_tensor_model_parallel_world_size() == 1:
         return input_
     # All-reduce.
-    if is_custom_nccl_enabled_for_all_reduce():
-        # TODO: support multiple parallel groups.
-        cupy_utils.all_reduce(input_)
-    else:
-        torch.distributed.all_reduce(input_,
-                                     group=get_tensor_model_parallel_group())
+    torch.distributed.all_reduce(input_,
+                                 group=get_tensor_model_parallel_group())
     return input_
 
 
diff --git a/vllm/model_executor/parallel_utils/cupy_utils.py b/vllm/model_executor/parallel_utils/cupy_utils.py
deleted file mode 100644
index f4cbdf6a6506..000000000000
--- a/vllm/model_executor/parallel_utils/cupy_utils.py
+++ /dev/null
@@ -1,115 +0,0 @@
-"""CuPy utilities for all-reduce.
-
-We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing
-CUDA graphs, because torch.distributed.all_reduce causes errors when capturing
-CUDA graphs.
-
-TODO: Remove this file when torch.distributed.all_reduce is fixed.
-"""
-import contextlib
-
-import torch
-from torch.distributed import ReduceOp
-
-try:
-    import cupy
-    from cupyx.distributed import NCCLBackend
-    from cupy.cuda import nccl
-except ImportError as e:
-    cupy = e
-    nccl = None
-
-    class NCCLBackend:
-        ...
-
-
-_OP_MAPPING = {
-    ReduceOp.SUM: "sum",
-    ReduceOp.PRODUCT: "prod",
-    ReduceOp.MIN: "min",
-    ReduceOp.MAX: "max",
-}
-
-
-class NCCLBackendWithBFloat16(NCCLBackend):
-    # This is enough to add bfloat16 support for most operations,
-    # but broadcast will fail (will require changes in compiled
-    # cupy code).
-    def _get_nccl_dtype_and_count(self, array, count=None):
-        nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count)
-        torch_dtype = getattr(array, "_torch_dtype", None)
-        if torch_dtype is torch.bfloat16:
-            nccl_dtype = nccl.NCCL_BFLOAT16
-        return nccl_dtype, count
-
-
-_NCCL_BACKEND = None
-_WORLD_SIZE = 0
-
-
-def is_initialized() -> bool:
-    """Returns whether the NCCL backend is initialized."""
-    return _NCCL_BACKEND is not None
-
-
-@contextlib.contextmanager
-def set_cupy_stream(stream: torch.cuda.Stream) -> None:
-    """Set the cuda stream for communication"""
-    cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream,
-                                           stream.device_index)
-    with cupy_stream:
-        yield
-
-
-def init_process_group(world_size: int, rank: int, host: str,
-                       port: int) -> None:
-    """Initializes the CuPy NCCL backend.
-
-    # TODO: handle NCCL timeouts.
-    """
-    assert not is_initialized()
-
-    if isinstance(cupy, Exception):
-        raise ImportError(
-            "NCCLBackend is not available. Please install cupy.") from cupy
-
-    # TODO(woosuk): Create TP and PP process groups for CuPy.
-    global _NCCL_BACKEND
-    global _WORLD_SIZE
-    assert world_size > 0, f"{world_size=} should be a positive integer"
-    assert 0 <= rank < world_size, (
-        f"{rank=} should be a integer between [0, {world_size})")
-
-    cupy.cuda.runtime.setDevice(torch.cuda.current_device())
-    _NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port)
-    _WORLD_SIZE = world_size
-
-
-def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
-    """All-reduces the input tensor across the process group."""
-    assert input_.is_cuda, f"{input_} should be a cuda tensor"
-    # Hack to support bfloat16
-    torch_dtype = input_.dtype
-    if torch_dtype is torch.bfloat16:
-        # We need to view as float16, otherwise
-        # cupy will fail. This will not change
-        # the underlying data.
-        input_ = input_.view(torch.float16)
-    cupy_input = cupy.asarray(input_)
-    cupy_input._torch_dtype = torch_dtype  # pylint: disable=protected-access
-    _NCCL_BACKEND.all_reduce(in_array=cupy_input,
-                             out_array=cupy_input,
-                             op=_OP_MAPPING[op])
-
-
-def destroy_process_group() -> None:
-    """Destroys the NCCL backend."""
-    global _NCCL_BACKEND
-    global _WORLD_SIZE
-    _NCCL_BACKEND = None
-    _WORLD_SIZE = 0
-
-
-def get_world_size() -> int:
-    """Returns the world size."""
-    return _WORLD_SIZE
diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/model_executor/parallel_utils/parallel_state.py
index dac1b9c48eee..9a5e2889381d 100644
--- a/vllm/model_executor/parallel_utils/parallel_state.py
+++ b/vllm/model_executor/parallel_utils/parallel_state.py
@@ -3,12 +3,9 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
-import contextlib
 
 import torch
 
-from vllm.model_executor.parallel_utils import cupy_utils
-
 # Tensor model parallel group that the current rank belongs to.
 _TENSOR_MODEL_PARALLEL_GROUP = None
 # Pipeline model parallel group that the current rank belongs to.
@@ -180,37 +177,3 @@ def destroy_model_parallel():
     _PIPELINE_MODEL_PARALLEL_GROUP = None
     global _PIPELINE_GLOBAL_RANKS
     _PIPELINE_GLOBAL_RANKS = None
-
-    # Destroy the cupy states if any.
-    cupy_utils.destroy_process_group()
-
-
-# Whether to use cupy for nccl all reduce.
-# We use cupy for all reduce when using CUDA graph, because torch.distributed
-# is not well supported by CUDA graph.
-_ENABLE_CUPY_FOR_ALL_REDUCE = False
-
-
-@contextlib.contextmanager
-def with_custom_nccl_for_all_reduce():
-    """use custom nccl instead of torch.distributed for all reduce"""
-    tp_size = get_tensor_model_parallel_world_size()
-    if tp_size == 1:
-        # No-op.
-        # NOTE(woosuk): We don't initialize CuPy when tp_size is 1.
-        yield
-    else:
-        global _ENABLE_CUPY_FOR_ALL_REDUCE
-        old = _ENABLE_CUPY_FOR_ALL_REDUCE
-        _ENABLE_CUPY_FOR_ALL_REDUCE = True
-
-        stream = torch.cuda.current_stream()
-        with cupy_utils.set_cupy_stream(stream):
-            yield
-        _ENABLE_CUPY_FOR_ALL_REDUCE = old
-
-
-def is_custom_nccl_enabled_for_all_reduce():
-    """check if custom nccl is enabled for all reduce"""
-    global _ENABLE_CUPY_FOR_ALL_REDUCE
-    return _ENABLE_CUPY_FOR_ALL_REDUCE
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 359b1b2e1970..276ef0708847 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -8,8 +8,6 @@ import torch.nn as nn
 from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
-from vllm.model_executor.parallel_utils.parallel_state import (
-    with_custom_nccl_for_all_reduce)
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 
@@ -459,8 +457,18 @@ class CUDAGraphRunner:
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with with_custom_nccl_for_all_reduce():
-            self.model(
+        self.model(
+            input_ids,
+            positions,
+            kv_caches,
+            input_metadata,
+        )
+        torch.cuda.synchronize()
+
+        # Capture the graph.
+        self.graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(self.graph, pool=memory_pool):
+            hidden_states = self.model(
                 input_ids,
                 positions,
                 kv_caches,
@@ -468,20 +476,6 @@ class CUDAGraphRunner:
             )
         torch.cuda.synchronize()
 
-        # Capture the graph.
-        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
-        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
-        self.graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
-            with with_custom_nccl_for_all_reduce():
-                hidden_states = self.model(
-                    input_ids,
-                    positions,
-                    kv_caches,
-                    input_metadata,
-                )
-        torch.cuda.synchronize()
-
         # Save the input and output buffers.
         self.input_buffers = {
             "input_ids": input_ids,
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 3e31737f2109..8698b1572150 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -8,7 +8,6 @@ import torch.distributed
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.model_executor import set_random_seed
-from vllm.model_executor.parallel_utils import cupy_utils
 from vllm.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel)
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@@ -47,7 +46,7 @@ class Worker:
         self.cache_events = None
         self.gpu_cache = None
 
-    def init_model(self, cupy_port: Optional[int] = None):
+    def init_model(self) -> None:
         # torch.distributed.all_reduce does not free the input tensor until
         # the synchronization point. This causes the memory usage to grow
         # as the number of all_reduce calls increases. This env var disables
@@ -71,7 +70,7 @@ class Worker:
 
         # Initialize the distributed environment.
         _init_distributed_environment(self.parallel_config, self.rank,
-                                      cupy_port, self.distributed_init_method)
+                                      self.distributed_init_method)
 
         # Initialize the model.
         set_random_seed(self.model_config.seed)
@@ -165,7 +164,6 @@ class Worker:
 def _init_distributed_environment(
     parallel_config: ParallelConfig,
     rank: int,
-    cupy_port: Optional[int],
     distributed_init_method: Optional[str] = None,
 ) -> None:
     """Initialize the distributed environment."""
@@ -188,29 +186,8 @@ def _init_distributed_environment(
         init_method=distributed_init_method,
     )
 
-    if cupy_utils.is_initialized():
-        cupy_world_size = cupy_utils.get_world_size()
-        if cupy_world_size != parallel_config.world_size:
-            raise RuntimeError(
-                "cupy.distributed is already initialized but the cupy world "
-                "size does not match parallel_config.world_size "
-                f"({cupy_world_size} vs. {parallel_config.world_size}).")
-    elif parallel_config.world_size > 1:
-        # NOTE(woosuk): We don't initialize CuPy process group when world size
-        # is 1.
-        # TODO(woosuk): Support multi-node connection.
-        cupy_utils.init_process_group(
-            world_size=parallel_config.world_size,
-            rank=rank,
-            host="localhost",
-            port=cupy_port,
-        )
-
-    if parallel_config.world_size > 1:
-        # A small all_reduce for warmup.
-        torch.distributed.all_reduce(torch.zeros(1).cuda())
-        cupy_utils.all_reduce(torch.zeros(1).cuda())
-
+    # A small all_reduce for warmup.
+    torch.distributed.all_reduce(torch.zeros(1).cuda())
     initialize_model_parallel(parallel_config.tensor_parallel_size,
                               parallel_config.pipeline_parallel_size)
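Reviewer note: the pattern this change now relies on, running the collective once eagerly and then capturing torch.distributed.all_reduce directly inside torch.cuda.graph, can be reduced to a small standalone sketch. The script below is illustrative only and is not part of the patch; it assumes a single GPU, a one-rank NCCL process group with a hard-coded localhost rendezvous, and a PyTorch/NCCL build recent enough that NCCL collectives are graph-capturable.

    # Hypothetical standalone sketch, not vLLM code: warm up, then capture an
    # all-reduce inside a CUDA graph, mirroring CUDAGraphRunner.capture above.
    import os

    import torch
    import torch.distributed as dist


    def main() -> None:
        # One-rank NCCL group so the example runs on a single GPU; the address,
        # port, and world size here are assumptions for illustration only.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group(backend="nccl", rank=0, world_size=1)
        torch.cuda.set_device(0)

        x = torch.ones(4, device="cuda")

        # Warm-up run outside the graph, like the first uncaptured forward pass
        # in CUDAGraphRunner.capture, so lazy initialization is not captured.
        dist.all_reduce(x)
        torch.cuda.synchronize()

        # Capture the collective inside a CUDA graph, then replay it. The graph
        # reuses x's memory, so refilling x and calling replay() repeats the op.
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            dist.all_reduce(x)
        graph.replay()
        torch.cuda.synchronize()

        dist.destroy_process_group()


    if __name__ == "__main__":
        main()

With one rank the all-reduce is numerically a no-op, but it should still exercise the capture path that previously required the CuPy backend; the warmup all_reduce kept in _init_distributed_environment plays the same role of initializing the communicator before any capture.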