[Core][Distributed] remove graph mode function (#4818)

youkaichao 2024-05-16 10:59:52 -07:00 committed by GitHub
parent b5853f9963
commit e08188081b
4 changed files with 63 additions and 54 deletions

View File

@@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
for sz in test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
with graph_capture():
with graph_capture() as graph_capture_context:
# use integers so result matches NCCL exactly
inp1 = torch.randint(1,
16, (sz, ),
@@ -62,7 +62,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
device=torch.cuda.current_device())
torch.cuda.synchronize()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
with torch.cuda.graph(graph,
stream=graph_capture_context.stream):
for i in range(num_communication):
out1 = tensor_model_parallel_all_reduce(inp1)
# the input buffer is immediately modified to test

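The two hunks above switch the test from a bare graph_capture() to capturing on the stream that the context manager now provides. A minimal sketch of the updated pattern, assuming torch.distributed and vLLM's tensor-parallel groups are already initialized the way this test initializes them:

import torch

from vllm.distributed.communication_op import (graph_capture,
                                               tensor_model_parallel_all_reduce)

# Assumes the distributed environment and tensor-parallel groups are already
# initialized, as the surrounding test does before entering graph_allreduce.
inp = torch.randint(1, 16, (1024, ),
                    dtype=torch.float32,
                    device=torch.cuda.current_device())
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with graph_capture() as graph_capture_context:
    # Record on the stream provided by graph_capture so the collective
    # backends selected during capture run on the same stream.
    with torch.cuda.graph(graph, stream=graph_capture_context.stream):
        out = tensor_model_parallel_all_reduce(inp)
torch.cuda.synchronize()

graph.replay()  # re-runs the captured all-reduce on the current buffer contents
torch.cuda.synchronize()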
View File

@@ -5,7 +5,7 @@ import pytest
import torch
from vllm.distributed.communication_op import ( # noqa
graph_mode, tensor_model_parallel_all_reduce)
graph_capture, tensor_model_parallel_all_reduce)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
ensure_model_parallel_initialized(2, 2)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
with graph_mode():
with graph_capture():
# two tp groups can communicate independently
if torch.distributed.get_rank() in [0, 1]:
tensor = tensor_model_parallel_all_reduce(tensor)

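For call sites that do not need the capture stream, the migration from the removed graph_mode() is mechanical. A sketch of this test's worker under the new API (distributed and model-parallel state assumed initialized, as above):

import torch

from vllm.distributed.communication_op import (graph_capture,
                                               tensor_model_parallel_all_reduce)

# Previously this collective was wrapped in `with graph_mode():`; that
# context manager is removed by this commit, and graph_capture() now carries
# the same communicator state changes.
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
with graph_capture():
    tensor = tensor_model_parallel_all_reduce(tensor)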
View File

@@ -1,5 +1,6 @@
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
@@ -13,45 +14,54 @@ from .parallel_state import (get_cpu_world_group,
get_tp_pynccl_communicator)
@contextmanager
def graph_mode():
# In graph mode, we have to be very careful about the collective
# operations. The current status is:
# allreduce \ Mode   |  Eager  |  Graph  |
# --------------------------------------------
# custom allreduce   | enabled | enabled |
# PyNccl             | disabled| enabled |
# torch.distributed  | enabled | disabled|
#
# Note that custom allreduce performs a runtime check: if the tensor size
# is too large, it falls back to the next available option.
# In summary: when using CUDA graph, we use either the custom
# all-reduce kernel or pynccl; when not using CUDA graph, we use
# either the custom all-reduce kernel or PyTorch NCCL. We always
# prioritize the custom all-reduce kernel but fall back to PyTorch
# NCCL or pynccl if it is disabled or not supported.
pynccl_comm = get_tp_pynccl_communicator()
if pynccl_comm is None:
context = nullcontext()
else:
context = pynccl_comm.change_state(enable=True,
stream=torch.cuda.current_stream())
with context:
yield
@dataclass
class GraphCaptureContext:
stream: torch.cuda.Stream
@contextmanager
def graph_capture():
"""
`graph_capture` is a context manager which should include the code that
`graph_capture` is a context manager which should surround the code that
is capturing the CUDA graph. Its main purpose is to ensure that some
operations will be run after the graph is captured, before the graph
is replayed.
is replayed. It returns a `GraphCaptureContext` object which contains the
necessary data for the graph capture. Currently, it only contains the
stream that the graph capture is running on. This stream is set as the
current CUDA stream when the context manager is entered, and the current
stream is reset to the default stream when it is exited. This is to ensure that
the graph capture is running on a separate stream from the default stream,
in order to explicitly distinguish the kernels to capture
from other kernels possibly launched in the background on the default stream.
"""
stream = torch.cuda.Stream()
graph_capture_context = GraphCaptureContext(stream)
ca_comm = get_tp_ca_communicator()
context = nullcontext() if ca_comm is None else ca_comm.capture()
with context:
yield
maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()
with torch.cuda.stream(stream), maybe_ca_context:
# In graph mode, we have to be very careful about the collective
# operations. The current status is:
# allreduce \ Mode   |  Eager  |  Graph  |
# --------------------------------------------
# custom allreduce   | enabled | enabled |
# PyNccl             | disabled| enabled |
# torch.distributed  | enabled | disabled|
#
# Note that custom allreduce performs a runtime check: if the tensor
# size is too large, it falls back to the next available option.
# In summary: when using CUDA graph, we use either the custom
# all-reduce kernel or pynccl; when not using CUDA graph, we use
# either the custom all-reduce kernel or PyTorch NCCL. We always
# prioritize the custom all-reduce kernel but fall back to PyTorch
# NCCL or pynccl if it is disabled or not supported.
pynccl_comm = get_tp_pynccl_communicator()
if pynccl_comm is None:
maybe_pynccl_context = nullcontext()
else:
maybe_pynccl_context = pynccl_comm.change_state(
enable=True, stream=torch.cuda.current_stream())
with maybe_pynccl_context:
yield graph_capture_context
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:

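The docstring above motivates recording on a dedicated stream so that the capture stays apart from kernels launched on the default stream. A self-contained illustration of that capture-on-a-side-stream pattern in plain PyTorch (no vLLM state involved):

import torch

# Plain-PyTorch sketch of capturing on a dedicated stream, mirroring what
# graph_capture() does internally; nothing here touches vLLM state.
static_in = torch.zeros(1024, device="cuda")
capture_stream = torch.cuda.Stream()

graph = torch.cuda.CUDAGraph()
with torch.cuda.stream(capture_stream):
    # Only work issued on this stream between graph begin/end is recorded;
    # kernels on the default stream stay out of the graph.
    with torch.cuda.graph(graph, stream=capture_stream):
        static_out = static_in * 2

static_in.fill_(3.0)
graph.replay()  # recomputes static_out from the captured kernels
torch.cuda.synchronize()
assert static_out.sum().item() == 3.0 * 2 * 1024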
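The comment table inside graph_capture() summarizes which all-reduce backend is eligible in eager versus graph mode. As a rough restatement only (the helper objects and the try_all_reduce/all_reduce method names below are illustrative, not vLLM's actual communicator API), the dispatch can be pictured as:

from typing import Optional

import torch
import torch.distributed as dist


def sketch_all_reduce(input_: torch.Tensor,
                      custom_ar=None,
                      pynccl=None,
                      graph_capturing: bool = False) -> torch.Tensor:
    # Custom all-reduce is preferred in both modes, but it declines tensors
    # that are too large (the runtime check mentioned in the comment above).
    if custom_ar is not None:
        out: Optional[torch.Tensor] = custom_ar.try_all_reduce(input_)
        if out is not None:
            return out
    if graph_capturing and pynccl is not None:
        # torch.distributed cannot be captured into the graph, so capture
        # falls back to the pynccl communicator.
        pynccl.all_reduce(input_)
    else:
        # Eager mode keeps pynccl disabled and uses torch.distributed.
        dist.all_reduce(input_)
    return input_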
View File

@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.distributed.communication_op import graph_capture, graph_mode
from vllm.distributed.communication_op import graph_capture
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
@@ -841,7 +841,7 @@ class ModelRunner:
bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
]
with graph_capture():
with graph_capture() as graph_capture_context:
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list):
@@ -877,6 +877,7 @@
kv_caches,
attn_metadata,
memory_pool=self.graph_memory_pool,
stream=graph_capture_context.stream,
)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[batch_size] = graph_runner
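The capture() hunks below drop the graph_mode() wrapper: the model is first run eagerly so one-time work (e.g. Triton autotuning) stays out of the recording, and the graph is then captured on the stream handed in by graph_capture() while reusing one memory pool across batch sizes. A minimal sketch of that warm-up/capture/replay shape with a plain module (no vLLM state assumed; capture_with_pool is an illustrative helper, not part of vLLM):

import torch


def capture_with_pool(model: torch.nn.Module,
                      example_input: torch.Tensor,
                      stream: torch.cuda.Stream,
                      memory_pool=None):
    # Warm-up run outside the graph so lazy initialization and autotuning
    # kernels are not baked into the capture.
    model(example_input)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=memory_pool, stream=stream):
        output = model(example_input)
    torch.cuda.synchronize()
    # graph.pool() can be handed to the next capture so all graphs share one
    # memory pool, as ModelRunner does via self.graph_memory_pool.
    return graph, output, graph.pool()


# Usage sketch:
#   stream = torch.cuda.Stream()
#   net = torch.nn.Linear(16, 16).cuda()
#   x = torch.randn(8, 16, device="cuda")
#   graph, out, pool = capture_with_pool(net, x, stream)
#   graph.replay()  # refills `out` from whatever is currently in x's buffer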
@@ -921,15 +922,27 @@ class CUDAGraphRunner:
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
memory_pool,
memory_pool: Optional[Tuple[int, int]],
stream: torch.cuda.Stream,
**kwargs,
) -> None:
assert self._graph is None
# Run the model once without capturing the graph.
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
with graph_mode():
self.model(
self.model(
input_ids,
positions,
kv_caches,
attn_metadata,
**kwargs,
)
torch.cuda.synchronize()
# Capture the graph.
self._graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
hidden_states = self.model(
input_ids,
positions,
kv_caches,
@@ -938,21 +951,6 @@
)
torch.cuda.synchronize()
# Capture the graph.
# NOTE(woosuk): Python 3.8 does not support multi-line with statements.
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
self._graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117
with graph_mode():
hidden_states = self.model(
input_ids,
positions,
kv_caches,
attn_metadata,
**kwargs,
)
torch.cuda.synchronize()
# Save the input and output buffers.
self.input_buffers = {
"input_ids": input_ids,