[Core][Distributed] remove graph mode function (#4818)
This commit is contained in:
parent b5853f9963
commit e08188081b
@@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
     for sz in test_sizes:
         for dtype in [torch.float32, torch.float16, torch.bfloat16]:
-            with graph_capture():
+            with graph_capture() as graph_capture_context:
                 # use integers so result matches NCCL exactly
                 inp1 = torch.randint(1,
                                      16, (sz, ),
@@ -62,7 +62,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
                                      device=torch.cuda.current_device())
                 torch.cuda.synchronize()
                 graph = torch.cuda.CUDAGraph()
-                with torch.cuda.graph(graph):
+                with torch.cuda.graph(graph,
+                                      stream=graph_capture_context.stream):
                     for i in range(num_communication):
                         out1 = tensor_model_parallel_all_reduce(inp1)
                         # the input buffer is immediately modified to test
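Note: the hunks above move the test from a bare graph_capture() to one that binds the yielded context and forwards its stream to torch.cuda.graph. A minimal, PyTorch-only sketch of that capture pattern follows (illustrative only, not part of this diff; shapes and values are invented):

import torch

capture_stream = torch.cuda.Stream()
static_inp = torch.randint(1, 16, (1024, ), dtype=torch.float16, device="cuda")
static_out = torch.empty_like(static_inp)

torch.cuda.synchronize()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=capture_stream):
    # Kernels launched here are recorded into `graph` rather than executed.
    torch.mul(static_inp, 2, out=static_out)

static_inp.fill_(3)   # modify the static input buffer in place
graph.replay()        # reruns the captured kernels against the new contents
torch.cuda.synchronize()
assert static_out.eq(6).all()

Passing the stream explicitly is what keeps capture-time kernels off the default stream, as the graph_capture docstring further down explains.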
@@ -5,7 +5,7 @@ import pytest
 import torch
 
 from vllm.distributed.communication_op import (  # noqa
-    graph_mode, tensor_model_parallel_all_reduce)
+    graph_capture, tensor_model_parallel_all_reduce)
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
 from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
     device = torch.device(f"cuda:{torch.distributed.get_rank()}")
     ensure_model_parallel_initialized(2, 2)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with graph_mode():
+    with graph_capture():
         # two tp groups can communicate independently
         if torch.distributed.get_rank() in [0, 1]:
             tensor = tensor_model_parallel_all_reduce(tensor)
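Note: a hedged sketch of what each rank in multiple_tp_with_vllm_worker_fn now does, assuming torch.distributed and the vLLM model-parallel groups are already initialized across four GPUs via ensure_model_parallel_initialized(2, 2). Here graph_capture() only switches the TP communicators into their capture-time state, so the yielded context object can be ignored:

import torch

from vllm.distributed.communication_op import (graph_capture,
                                               tensor_model_parallel_all_reduce)


def allreduce_in_graph_state() -> torch.Tensor:
    device = torch.device(f"cuda:{torch.distributed.get_rank()}")
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
    with graph_capture():
        # each 2-GPU TP group reduces independently, so the result is all 2s
        tensor = tensor_model_parallel_all_reduce(tensor)
    return tensor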
@@ -1,5 +1,6 @@
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -13,45 +14,54 @@ from .parallel_state import (get_cpu_world_group,
                              get_tp_pynccl_communicator)
 
 
-@contextmanager
-def graph_mode():
-    # In graph mode, we have to be very careful about the collective
-    # operations. The current status is:
-    # allreduce \ Mode   |  Eager  |  Graph  |
-    # --------------------------------------------
-    # custom allreduce   | enabled | enabled |
-    # PyNccl             | disabled| enabled |
-    # torch.distributed  | enabled | disabled|
-    #
-    # Note that custom allreduce will have a runtime check, if the tensor size
-    # is too large, it will fallback to the next available option.
-    # In summary: When using CUDA graph, we use
-    # either custom all-reduce kernel or pynccl. When not using CUDA
-    # graph, we use either custom all-reduce kernel or PyTorch NCCL.
-    # We always prioritize using custom all-reduce kernel but fall back
-    # to PyTorch or pynccl if it is disabled or not supported.
-    pynccl_comm = get_tp_pynccl_communicator()
-    if pynccl_comm is None:
-        context = nullcontext()
-    else:
-        context = pynccl_comm.change_state(enable=True,
-                                           stream=torch.cuda.current_stream())
-    with context:
-        yield
+@dataclass
+class GraphCaptureContext:
+    stream: torch.cuda.Stream
 
 
 @contextmanager
 def graph_capture():
     """
-    `graph_capture` is a context manager which should include the code that
+    `graph_capture` is a context manager which should surround the code that
     is capturing the CUDA graph. Its main purpose is to ensure that the
     some operations will be run after the graph is captured, before the graph
-    is replayed.
+    is replayed. It returns a `GraphCaptureContext` object which contains the
+    necessary data for the graph capture. Currently, it only contains the
+    stream that the graph capture is running on. This stream is set to the
+    current CUDA stream when the context manager is entered and reset to the
+    default stream when the context manager is exited. This is to ensure that
+    the graph capture is running on a separate stream from the default stream,
+    in order to explicitly distinguish the kernels to capture
+    from other kernels possibly launched on background in the default stream.
     """
+    stream = torch.cuda.Stream()
+    graph_capture_context = GraphCaptureContext(stream)
     ca_comm = get_tp_ca_communicator()
-    context = nullcontext() if ca_comm is None else ca_comm.capture()
-    with context:
-        yield
+    maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()
+    with torch.cuda.stream(stream), maybe_ca_context:
+        # In graph mode, we have to be very careful about the collective
+        # operations. The current status is:
+        # allreduce \ Mode   |  Eager  |  Graph  |
+        # --------------------------------------------
+        # custom allreduce   | enabled | enabled |
+        # PyNccl             | disabled| enabled |
+        # torch.distributed  | enabled | disabled|
+        #
+        # Note that custom allreduce will have a runtime check, if the tensor
+        # size is too large, it will fallback to the next available option.
+        # In summary: When using CUDA graph, we use
+        # either custom all-reduce kernel or pynccl. When not using CUDA
+        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
+        # We always prioritize using custom all-reduce kernel but fall back
+        # to PyTorch or pynccl if it is disabled or not supported.
+        pynccl_comm = get_tp_pynccl_communicator()
+        if pynccl_comm is None:
+            maybe_pynccl_context = nullcontext()
+        else:
+            maybe_pynccl_context = pynccl_comm.change_state(
+                enable=True, stream=torch.cuda.current_stream())
+        with maybe_pynccl_context:
+            yield graph_capture_context
 
 
 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
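Note: the comment table kept inside graph_capture is the core of this module: it says which all-reduce backend may be used eagerly versus while a CUDA graph is being captured. A rough sketch of that fallback order follows; it is illustrative rather than vLLM's actual dispatch code, and ca_comm / pynccl_comm with their method names are simplified stand-ins for the objects returned by get_tp_ca_communicator() and get_tp_pynccl_communicator():

import torch
import torch.distributed as dist


def all_reduce_with_fallback(input_: torch.Tensor, ca_comm, pynccl_comm,
                             capturing_graph: bool) -> torch.Tensor:
    # 1. Custom all-reduce kernel: preferred in both modes, but it runs a
    #    size check at call time and may decline large tensors.
    if ca_comm is not None:
        out = ca_comm.custom_all_reduce(input_)
        if out is not None:
            return out
    # 2. PyNccl: only enabled while a CUDA graph is being captured.
    if capturing_graph and pynccl_comm is not None:
        pynccl_comm.all_reduce(input_)  # in-place
        return input_
    # 3. torch.distributed NCCL: the eager-mode default.
    dist.all_reduce(input_)
    return input_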
@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict
-from vllm.distributed.communication_op import graph_capture, graph_mode
+from vllm.distributed.communication_op import graph_capture
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
@@ -841,7 +841,7 @@ class ModelRunner:
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]
 
-        with graph_capture():
+        with graph_capture() as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for batch_size in reversed(batch_size_capture_list):
@@ -877,6 +877,7 @@ class ModelRunner:
                     kv_caches,
                     attn_metadata,
                     memory_pool=self.graph_memory_pool,
+                    stream=graph_capture_context.stream,
                 )
                 self.graph_memory_pool = graph_runner.graph.pool()
                 self.graph_runners[batch_size] = graph_runner
@@ -921,15 +922,27 @@ class CUDAGraphRunner:
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-        memory_pool,
+        memory_pool: Optional[Tuple[int, int]],
+        stream: torch.cuda.Stream,
         **kwargs,
     ) -> None:
         assert self._graph is None
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with graph_mode():
-            self.model(
+        self.model(
+            input_ids,
+            positions,
+            kv_caches,
+            attn_metadata,
+            **kwargs,
+        )
+        torch.cuda.synchronize()
+
+        # Capture the graph.
+        self._graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
+            hidden_states = self.model(
                 input_ids,
                 positions,
                 kv_caches,
@@ -938,21 +951,6 @@ class CUDAGraphRunner:
             )
         torch.cuda.synchronize()
 
-        # Capture the graph.
-        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
-        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
-        self._graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(self._graph, pool=memory_pool):  # noqa: SIM117
-            with graph_mode():
-                hidden_states = self.model(
-                    input_ids,
-                    positions,
-                    kv_caches,
-                    attn_metadata,
-                    **kwargs,
-                )
-        torch.cuda.synchronize()
-
         # Save the input and output buffers.
         self.input_buffers = {
             "input_ids": input_ids,
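Note: CUDAGraphRunner.capture now drops the graph_mode wrapper entirely. It runs the model once eagerly so one-time kernel launches (e.g. Triton autotuning) stay out of the capture, then records the forward pass into a shared memory pool on the stream supplied by graph_capture. A PyTorch-only sketch of that warm-up-then-capture pattern (illustrative, not vLLM code; model, input, and pool are placeholders):

import torch


def capture_forward(model: torch.nn.Module, static_input: torch.Tensor,
                    stream: torch.cuda.Stream, memory_pool=None):
    # Warm-up pass: runs eagerly so one-time kernel launches are not recorded.
    model(static_input)
    torch.cuda.synchronize()

    # Record the forward pass; `pool` lets consecutive captures share memory.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=memory_pool, stream=stream):
        static_output = model(static_input)
    torch.cuda.synchronize()

    # graph.pool() can be fed back in as `memory_pool` for the next capture,
    # mirroring self.graph_memory_pool in ModelRunner above.
    return graph, static_output, graph.pool()

Reusing the pool across captures is what lets ModelRunner record one graph per batch size without multiplying memory usage.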