[Core][Distributed] remove graph mode function (#4818)

youkaichao 2024-05-16 10:59:52 -07:00 committed by GitHub
parent b5853f9963
commit e08188081b
4 changed files with 63 additions and 54 deletions


@@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
     for sz in test_sizes:
         for dtype in [torch.float32, torch.float16, torch.bfloat16]:
-            with graph_capture():
+            with graph_capture() as graph_capture_context:
                 # use integers so result matches NCCL exactly
                 inp1 = torch.randint(1,
                                      16, (sz, ),
@@ -62,7 +62,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
                                      device=torch.cuda.current_device())
                 torch.cuda.synchronize()
                 graph = torch.cuda.CUDAGraph()
-                with torch.cuda.graph(graph):
+                with torch.cuda.graph(graph,
+                                      stream=graph_capture_context.stream):
                     for i in range(num_communication):
                         out1 = tensor_model_parallel_all_reduce(inp1)
                         # the input buffer is immediately modified to test


@@ -5,7 +5,7 @@ import pytest
 import torch

 from vllm.distributed.communication_op import (  # noqa
-    graph_mode, tensor_model_parallel_all_reduce)
+    graph_capture, tensor_model_parallel_all_reduce)
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
 from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
     device = torch.device(f"cuda:{torch.distributed.get_rank()}")
     ensure_model_parallel_initialized(2, 2)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with graph_mode():
+    with graph_capture():
         # two tp groups can communicate independently
         if torch.distributed.get_rank() in [0, 1]:
             tensor = tensor_model_parallel_all_reduce(tensor)


@@ -1,5 +1,6 @@
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
@@ -13,45 +14,54 @@ from .parallel_state import (get_cpu_world_group,
                              get_tp_pynccl_communicator)


-@contextmanager
-def graph_mode():
-    # In graph mode, we have to be very careful about the collective
-    # operations. The current status is:
-    #     allreduce \ Mode   |  Eager  |  Graph  |
-    # --------------------------------------------
-    # custom allreduce       | enabled | enabled |
-    # PyNccl                 | disabled| enabled |
-    # torch.distributed      | enabled | disabled|
-    #
-    # Note that custom allreduce will have a runtime check, if the tensor size
-    # is too large, it will fallback to the next available option.
-    # In summary: When using CUDA graph, we use
-    #  either custom all-reduce kernel or pynccl. When not using CUDA
-    #  graph, we use either custom all-reduce kernel or PyTorch NCCL.
-    #  We always prioritize using custom all-reduce kernel but fall back
-    #  to PyTorch or pynccl if it is disabled or not supported.
-    pynccl_comm = get_tp_pynccl_communicator()
-    if pynccl_comm is None:
-        context = nullcontext()
-    else:
-        context = pynccl_comm.change_state(enable=True,
-                                           stream=torch.cuda.current_stream())
-    with context:
-        yield
+@dataclass
+class GraphCaptureContext:
+    stream: torch.cuda.Stream


 @contextmanager
 def graph_capture():
     """
-    `graph_capture` is a context manager which should include the code that
+    `graph_capture` is a context manager which should surround the code that
     is capturing the CUDA graph. Its main purpose is to ensure that the
     some operations will be run after the graph is captured, before the graph
-    is replayed.
+    is replayed. It returns a `GraphCaptureContext` object which contains the
+    necessary data for the graph capture. Currently, it only contains the
+    stream that the graph capture is running on. This stream is set to the
+    current CUDA stream when the context manager is entered and reset to the
+    default stream when the context manager is exited. This is to ensure that
+    the graph capture is running on a separate stream from the default stream,
+    in order to explicitly distinguish the kernels to capture
+    from other kernels possibly launched on background in the default stream.
     """
+    stream = torch.cuda.Stream()
+    graph_capture_context = GraphCaptureContext(stream)
     ca_comm = get_tp_ca_communicator()
-    context = nullcontext() if ca_comm is None else ca_comm.capture()
-    with context:
-        yield
+    maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()
+    with torch.cuda.stream(stream), maybe_ca_context:
+        # In graph mode, we have to be very careful about the collective
+        # operations. The current status is:
+        #     allreduce \ Mode   |  Eager  |  Graph  |
+        # --------------------------------------------
+        # custom allreduce       | enabled | enabled |
+        # PyNccl                 | disabled| enabled |
+        # torch.distributed      | enabled | disabled|
+        #
+        # Note that custom allreduce will have a runtime check, if the tensor
+        # size is too large, it will fallback to the next available option.
+        # In summary: When using CUDA graph, we use
+        #  either custom all-reduce kernel or pynccl. When not using CUDA
+        #  graph, we use either custom all-reduce kernel or PyTorch NCCL.
+        #  We always prioritize using custom all-reduce kernel but fall back
+        #  to PyTorch or pynccl if it is disabled or not supported.
+        pynccl_comm = get_tp_pynccl_communicator()
+        if pynccl_comm is None:
+            maybe_pynccl_context = nullcontext()
+        else:
+            maybe_pynccl_context = pynccl_comm.change_state(
+                enable=True, stream=torch.cuda.current_stream())
+        with maybe_pynccl_context:
+            yield graph_capture_context


 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
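The example below is not part of the commit; it is a minimal sketch of the caller-side pattern that the new `graph_capture` expects, mirroring the updated test above. It assumes vLLM's distributed environment and the tensor-parallel group are already initialized, and the helper name `capture_allreduce_graph` is made up for illustration.

import torch

from vllm.distributed.communication_op import (
    graph_capture, tensor_model_parallel_all_reduce)


def capture_allreduce_graph(size: int = 1024) -> torch.cuda.CUDAGraph:
    # Hypothetical helper, not a vLLM API.
    inp = torch.ones(size,
                     dtype=torch.float32,
                     device=torch.cuda.current_device())
    graph = torch.cuda.CUDAGraph()
    # graph_capture() switches to a dedicated stream and enables the pynccl
    # communicator, so the CUDA graph must be recorded on that same stream.
    with graph_capture() as graph_capture_context:
        with torch.cuda.graph(graph,
                              stream=graph_capture_context.stream):
            tensor_model_parallel_all_reduce(inp)
    torch.cuda.synchronize()
    return graph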


@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict
-from vllm.distributed.communication_op import graph_capture, graph_mode
+from vllm.distributed.communication_op import graph_capture
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
@@ -841,7 +841,7 @@ class ModelRunner:
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]

-        with graph_capture():
+        with graph_capture() as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for batch_size in reversed(batch_size_capture_list):
@@ -877,6 +877,7 @@ class ModelRunner:
                     kv_caches,
                     attn_metadata,
                     memory_pool=self.graph_memory_pool,
+                    stream=graph_capture_context.stream,
                 )
                 self.graph_memory_pool = graph_runner.graph.pool()
                 self.graph_runners[batch_size] = graph_runner
@@ -921,15 +922,27 @@ class CUDAGraphRunner:
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-        memory_pool,
+        memory_pool: Optional[Tuple[int, int]],
+        stream: torch.cuda.Stream,
         **kwargs,
     ) -> None:
         assert self._graph is None

         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with graph_mode():
-            self.model(
+        self.model(
+            input_ids,
+            positions,
+            kv_caches,
+            attn_metadata,
+            **kwargs,
+        )
+        torch.cuda.synchronize()
+
+        # Capture the graph.
+        self._graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
+            hidden_states = self.model(
                 input_ids,
                 positions,
                 kv_caches,
@@ -938,21 +951,6 @@ class CUDAGraphRunner:
             )
         torch.cuda.synchronize()

-        # Capture the graph.
-        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
-        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
-        self._graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(self._graph, pool=memory_pool):  # noqa: SIM117
-            with graph_mode():
-                hidden_states = self.model(
-                    input_ids,
-                    positions,
-                    kv_caches,
-                    attn_metadata,
-                    **kwargs,
-                )
-        torch.cuda.synchronize()
         # Save the input and output buffers.
         self.input_buffers = {
             "input_ids": input_ids,