[Core][Distributed] remove graph mode function (#4818)
parent b5853f9963
commit e08188081b

@@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
     for sz in test_sizes:
         for dtype in [torch.float32, torch.float16, torch.bfloat16]:
-            with graph_capture():
+            with graph_capture() as graph_capture_context:
                 # use integers so result matches NCCL exactly
                 inp1 = torch.randint(1,
                                      16, (sz, ),
@@ -62,7 +62,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
                                      device=torch.cuda.current_device())
                 torch.cuda.synchronize()
                 graph = torch.cuda.CUDAGraph()
-                with torch.cuda.graph(graph):
+                with torch.cuda.graph(graph,
+                                      stream=graph_capture_context.stream):
                     for i in range(num_communication):
                         out1 = tensor_model_parallel_all_reduce(inp1)
                         # the input buffer is immediately modified to test
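
Outside the diff itself, here is a minimal sketch of the pattern the updated test exercises: capture the all-reduce on the stream yielded by graph_capture(), then replay the graph and compare against an eager all-reduce. It assumes the tensor-parallel process group has already been initialized by the surrounding test harness; the tensor size is illustrative.

import torch

from vllm.distributed.communication_op import (
    graph_capture, tensor_model_parallel_all_reduce)

# Eager reference result (all ranks contribute a tensor of ones).
inp = torch.ones(1024,
                 dtype=torch.float32,
                 device=torch.cuda.current_device())
eager_out = tensor_model_parallel_all_reduce(inp.clone())

with graph_capture() as graph_capture_context:
    graph = torch.cuda.CUDAGraph()
    # Capture on the stream that graph_capture() switched to, so the
    # collective backends and the CUDA graph agree on the stream.
    with torch.cuda.graph(graph, stream=graph_capture_context.stream):
        out = tensor_model_parallel_all_reduce(inp)

# Kernels are only recorded during capture; replay actually runs them.
graph.replay()
torch.cuda.synchronize()
assert torch.allclose(out, eager_out)
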
@@ -5,7 +5,7 @@ import pytest
 import torch

 from vllm.distributed.communication_op import (  # noqa
-    graph_mode, tensor_model_parallel_all_reduce)
+    graph_capture, tensor_model_parallel_all_reduce)
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
 from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
     device = torch.device(f"cuda:{torch.distributed.get_rank()}")
     ensure_model_parallel_initialized(2, 2)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with graph_mode():
+    with graph_capture():
         # two tp groups can communicate independently
         if torch.distributed.get_rank() in [0, 1]:
             tensor = tensor_model_parallel_all_reduce(tensor)
@@ -1,5 +1,6 @@
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
@@ -13,45 +14,54 @@ from .parallel_state import (get_cpu_world_group,
                              get_tp_pynccl_communicator)


-@contextmanager
-def graph_mode():
-    # In graph mode, we have to be very careful about the collective
-    # operations. The current status is:
-    # allreduce \ Mode   |  Eager  |  Graph  |
-    # --------------------------------------------
-    # custom allreduce   | enabled | enabled |
-    # PyNccl             | disabled| enabled |
-    # torch.distributed  | enabled | disabled|
-    #
-    # Note that custom allreduce will have a runtime check, if the tensor size
-    # is too large, it will fallback to the next available option.
-    # In summary: When using CUDA graph, we use
-    # either custom all-reduce kernel or pynccl. When not using CUDA
-    # graph, we use either custom all-reduce kernel or PyTorch NCCL.
-    # We always prioritize using custom all-reduce kernel but fall back
-    # to PyTorch or pynccl if it is disabled or not supported.
-    pynccl_comm = get_tp_pynccl_communicator()
-    if pynccl_comm is None:
-        context = nullcontext()
-    else:
-        context = pynccl_comm.change_state(enable=True,
-                                           stream=torch.cuda.current_stream())
-    with context:
-        yield
+@dataclass
+class GraphCaptureContext:
+    stream: torch.cuda.Stream


 @contextmanager
 def graph_capture():
     """
-    `graph_capture` is a context manager which should include the code that
+    `graph_capture` is a context manager which should surround the code that
     is capturing the CUDA graph. Its main purpose is to ensure that
     some operations will be run after the graph is captured, before the graph
-    is replayed.
+    is replayed. It returns a `GraphCaptureContext` object which contains the
+    necessary data for the graph capture. Currently, it only contains the
+    stream that the graph capture is running on. This stream is set to the
+    current CUDA stream when the context manager is entered and reset to the
+    default stream when the context manager is exited. This is to ensure that
+    the graph capture is running on a separate stream from the default stream,
+    in order to explicitly distinguish the kernels to capture
+    from other kernels possibly launched in the background on the default stream.
     """
+    stream = torch.cuda.Stream()
+    graph_capture_context = GraphCaptureContext(stream)
     ca_comm = get_tp_ca_communicator()
-    context = nullcontext() if ca_comm is None else ca_comm.capture()
-    with context:
-        yield
+    maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()
+    with torch.cuda.stream(stream), maybe_ca_context:
+        # In graph mode, we have to be very careful about the collective
+        # operations. The current status is:
+        # allreduce \ Mode   |  Eager  |  Graph  |
+        # --------------------------------------------
+        # custom allreduce   | enabled | enabled |
+        # PyNccl             | disabled| enabled |
+        # torch.distributed  | enabled | disabled|
+        #
+        # Note that custom allreduce will have a runtime check, if the tensor
+        # size is too large, it will fallback to the next available option.
+        # In summary: When using CUDA graph, we use
+        # either custom all-reduce kernel or pynccl. When not using CUDA
+        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
+        # We always prioritize using custom all-reduce kernel but fall back
+        # to PyTorch or pynccl if it is disabled or not supported.
+        pynccl_comm = get_tp_pynccl_communicator()
+        if pynccl_comm is None:
+            maybe_pynccl_context = nullcontext()
+        else:
+            maybe_pynccl_context = pynccl_comm.change_state(
+                enable=True, stream=torch.cuda.current_stream())
+        with maybe_pynccl_context:
+            yield graph_capture_context


 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
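
The docstring above says that graph_capture() moves work onto a dedicated capture stream and restores the previous stream on exit. A small, hedged illustration of what that means for a caller, assuming vLLM's tensor-parallel state has already been initialized (the two communicator getters consult it):

import torch

from vllm.distributed.communication_op import graph_capture

# Work issued here goes to whatever stream is current (usually the default).
stream_before = torch.cuda.current_stream()

with graph_capture() as graph_capture_context:
    # torch.cuda.stream(stream) has made the dedicated capture stream
    # current, so kernels launched here stay separate from anything still
    # running on the previous stream. Inside this block, pynccl is enabled
    # and torch.distributed is avoided, per the comment table above.
    assert torch.cuda.current_stream() == graph_capture_context.stream
    assert graph_capture_context.stream != stream_before

# On exit, the previous stream is current again.
assert torch.cuda.current_stream() == stream_before
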
@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict
-from vllm.distributed.communication_op import graph_capture, graph_mode
+from vllm.distributed.communication_op import graph_capture
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
@@ -841,7 +841,7 @@ class ModelRunner:
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]

-        with graph_capture():
+        with graph_capture() as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for batch_size in reversed(batch_size_capture_list):
@@ -877,6 +877,7 @@ class ModelRunner:
                     kv_caches,
                     attn_metadata,
                     memory_pool=self.graph_memory_pool,
+                    stream=graph_capture_context.stream,
                 )
                 self.graph_memory_pool = graph_runner.graph.pool()
                 self.graph_runners[batch_size] = graph_runner
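
Reduced to a sketch, the capture loop that these two ModelRunner hunks modify looks as follows. The callables make_runner and make_inputs are placeholders standing in for CUDAGraphRunner construction and dummy-input preparation, not vLLM API; the point is that one graph_capture() context provides a single shared stream, and each captured graph hands its memory pool to the next capture.

from typing import Callable, Dict, List, Optional, Tuple

from vllm.distributed.communication_op import graph_capture


def capture_all(batch_sizes: List[int],
                make_runner: Callable,     # placeholder runner factory
                make_inputs: Callable) -> Dict[int, object]:
    graph_runners: Dict[int, object] = {}
    memory_pool: Optional[Tuple[int, int]] = None
    with graph_capture() as graph_capture_context:
        # Capturing the largest batch size first may help reduce the memory
        # usage of the CUDA graphs (same NOTE as in the hunk above).
        for batch_size in sorted(batch_sizes, reverse=True):
            runner = make_runner(batch_size)
            runner.capture(*make_inputs(batch_size),
                           memory_pool=memory_pool,
                           stream=graph_capture_context.stream)
            memory_pool = runner.graph.pool()
            graph_runners[batch_size] = runner
    return graph_runners
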
@@ -921,15 +922,27 @@ class CUDAGraphRunner:
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-        memory_pool,
+        memory_pool: Optional[Tuple[int, int]],
+        stream: torch.cuda.Stream,
         **kwargs,
     ) -> None:
         assert self._graph is None
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with graph_mode():
-            self.model(
+        self.model(
+            input_ids,
+            positions,
+            kv_caches,
+            attn_metadata,
+            **kwargs,
+        )
+        torch.cuda.synchronize()
+
+        # Capture the graph.
+        self._graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
+            hidden_states = self.model(
                 input_ids,
                 positions,
                 kv_caches,
@@ -938,21 +951,6 @@ class CUDAGraphRunner:
             )
         torch.cuda.synchronize()
-
-        # Capture the graph.
-        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
-        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
-        self._graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(self._graph, pool=memory_pool):  # noqa: SIM117
-            with graph_mode():
-                hidden_states = self.model(
-                    input_ids,
-                    positions,
-                    kv_caches,
-                    attn_metadata,
-                    **kwargs,
-                )
-        torch.cuda.synchronize()

         # Save the input and output buffers.
         self.input_buffers = {
             "input_ids": input_ids,
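
Taken together, the last two hunks flatten CUDAGraphRunner.capture into a plain warm-up run followed by a single capture on the caller-provided stream, which is what removes the nested graph_mode() block and the Python 3.8 noqa workaround. A generic sketch of that pattern, independent of vLLM's model and metadata classes (run_model is a placeholder callable):

from typing import Callable, Optional, Tuple

import torch


def capture_on_stream(run_model: Callable[[], None],
                      memory_pool: Optional[Tuple[int, int]],
                      stream: torch.cuda.Stream) -> torch.cuda.CUDAGraph:
    # Warm-up run outside the graph so one-time kernel launches
    # (e.g. Triton autotuning) are not recorded into the capture.
    run_model()
    torch.cuda.synchronize()

    # Capture on the caller's stream, reusing the shared memory pool.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=memory_pool, stream=stream):
        run_model()
    torch.cuda.synchronize()
    return graph
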