mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 09:15:40 +08:00

[Perf] Enable cuda graph for deepepHT, 5.3% throughput improvement, 4.4% TTFT improvement (#29558)
Signed-off-by: yewentao256 <zhyanwentao@126.com>

parent: dce6d229f7
commit: 17eb25e327
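
Reviewer note: the snippet below is a minimal sketch of the setup this commit targets, mirroring the new tests added further down in this diff. The data_parallel_size of 8 is illustrative; treat this as a demonstration, not a canonical invocation.

    # Sketch: DeepEP high-throughput all2all with data parallelism, where
    # CUDA graphs are no longer forced off by the platform hook.
    from vllm.config import CompilationConfig, ParallelConfig, VllmConfig
    from vllm.config.compilation import CompilationMode

    config = VllmConfig(
        parallel_config=ParallelConfig(
            all2all_backend="deepep_high_throughput",
            data_parallel_size=8,
        ),
        compilation_config=CompilationConfig(mode=CompilationMode.VLLM_COMPILE),
    )
    # The MoE custom ops are now split out of the compiled graph so that
    # piecewise CUDA graphs can be captured around them.
    print(config.compilation_config.splitting_ops)
    print(config.compilation_config.cudagraph_mode)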
@@ -10,7 +10,7 @@ from pydantic import ValidationError
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
-from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
+from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig
 from vllm.config.compilation import CompilationMode, PassConfig
 from vllm.engine.arg_utils import EngineArgs
 from vllm.logger import _print_warning_once
@@ -235,6 +235,70 @@ def test_splitting_ops_dynamic():
     assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
 
 
+def test_moe_splitting_ops_deepep_ht_piecewise():
+    # Non-inductor, non-attn-fusion case: DeepEP HT with dp>1
+    # should add MoE ops to splitting_ops on top of attention ops.
+    config = VllmConfig(
+        parallel_config=ParallelConfig(
+            all2all_backend="deepep_high_throughput",
+            data_parallel_size=8,
+        ),
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+        ),
+    )
+    splitting_ops = config.compilation_config.splitting_ops
+    assert splitting_ops is not None
+    assert "vllm::moe_forward" in splitting_ops
+    assert "vllm::moe_forward_shared" in splitting_ops
+
+
+def test_moe_splitting_ops_deepep_ht_inductor_partition():
+    # Inductor partition case: user-provided splitting_ops should be
+    # preserved and MoE ops should be appended for DeepEP HT with dp>1.
+    config = VllmConfig(
+        parallel_config=ParallelConfig(
+            all2all_backend="deepep_high_throughput",
+            data_parallel_size=8,
+        ),
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            use_inductor_graph_partition=True,
+            splitting_ops=[
+                "vllm::unified_attention",
+                "vllm::moe_forward",
+                "vllm::moe_forward_shared",
+            ],
+        ),
+    )
+    splitting_ops = config.compilation_config.splitting_ops
+    assert splitting_ops == [
+        "vllm::unified_attention",
+        "vllm::moe_forward",
+        "vllm::moe_forward_shared",
+    ]
+
+
+def test_moe_splitting_ops_deepep_ht_attn_fusion_no_inductor():
+    # Pure attn-fusion case without inductor partition: even with
+    # DeepEP HT and dp>1, we should not re-enable piecewise compilation
+    # or add MoE ops into splitting_ops.
+    config = VllmConfig(
+        parallel_config=ParallelConfig(
+            all2all_backend="deepep_high_throughput",
+            data_parallel_size=8,
+        ),
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            pass_config={"enable_attn_fusion": True, "enable_noop": True},
+            custom_ops=["+quant_fp8"],
+            cudagraph_mode=CUDAGraphMode.PIECEWISE,
+        ),
+    )
+    assert config.compilation_config.splitting_ops == []
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
+
+
 def test_should_split():
     import torch
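
Reviewer note: the three tests above cover the piecewise, inductor-partition, and attn-fusion paths. One path they do not pin down directly is the full-cudagraph fallback. The sketch below infers the expected outcome from the set_splitting_ops_for_v1 changes in the next hunks rather than from an existing test, so treat the assertion as an assumption.

    # Hedged sketch: starting from FULL, DeepEP HT with dp>1 should be
    # downgraded to PIECEWISE once the MoE ops are split out (inferred
    # from the new fallback branch below, not copied from an existing test).
    from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig
    from vllm.config.compilation import CompilationMode

    config = VllmConfig(
        parallel_config=ParallelConfig(
            all2all_backend="deepep_high_throughput",
            data_parallel_size=8,
        ),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            cudagraph_mode=CUDAGraphMode.FULL,
        ),
    )
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE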
@@ -966,7 +966,9 @@ class CompilationConfig:
         # May get recomputed in the model runner if adjustment is needed for spec-decode
         self.compute_bs_to_padded_graph_size()
 
-    def set_splitting_ops_for_v1(self):
+    def set_splitting_ops_for_v1(
+        self, all2all_backend: str | None = None, data_parallel_size: int | None = None
+    ):
         # To be compatible with OOT hardware plugin platforms (for example
         # vllm-ascend) which currently only support sequence parallelism
         # in eager mode.
         if self.mode != CompilationMode.VLLM_COMPILE:
@@ -981,50 +983,83 @@ class CompilationConfig:
                 "mode is CompilationMode.VLLM_COMPILE"
             )
 
-        if self.use_inductor_graph_partition:
-            self.set_splitting_ops_for_inductor_graph_partition()
-            return
+        added_default_splitting_ops = False
 
-        if self.pass_config.fuse_attn_quant:
-            # here use_inductor_graph_partition is False
+        if self.pass_config.fuse_attn_quant and not self.use_inductor_graph_partition:
             self.set_splitting_ops_for_attn_fusion()
-            return
-
-        if self.splitting_ops is None:
-            # NOTE: When using full cudagraph, instead of setting an empty
-            # list and capture the full cudagraph inside the flattened fx
-            # graph, we keep the piecewise fx graph structure but capture
-            # the full cudagraph outside the fx graph. This reduces some
-            # cpu overhead when the runtime batch_size is not cudagraph
-            # captured. see https://github.com/vllm-project/vllm/pull/20059
-            # for details. Make a copy to avoid mutating the class-level
-            # list via reference.
-            self.splitting_ops = list(self._attention_ops)
-        elif len(self.splitting_ops) == 0:
-            logger.warning_once("Using piecewise compilation with empty splitting_ops")
-            if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
-                logger.warning_once(
-                    "Piecewise compilation with empty splitting_ops do not"
-                    "contains piecewise cudagraph. Setting cudagraph_"
-                    "mode to NONE. Hint: If you are using attention backends "
-                    "that support cudagraph, consider manually setting "
-                    "cudagraph_mode to FULL or FULL_DECODE_ONLY to enable "
-                    "full cudagraphs."
-                )
+        else:
+            if self.splitting_ops is None:
+                # NOTE: When using full cudagraph, instead of setting an empty
+                # list and capture the full cudagraph inside the flattened fx
+                # graph, we keep the piecewise fx graph structure but capture
+                # the full cudagraph outside the fx graph. This reduces some
+                # cpu overhead when the runtime batch_size is not cudagraph
+                # captured. see https://github.com/vllm-project/vllm/pull/20059
+                # for details. Make a copy to avoid mutating the class-level
+                # list via reference.
+                self.splitting_ops = list(self._attention_ops)
+                added_default_splitting_ops = True
+            elif len(self.splitting_ops) == 0:
+                logger.warning_once(
+                    "Using piecewise compilation with empty splitting_ops"
+                )
+                if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
+                    logger.warning_once(
+                        "Piecewise compilation with empty splitting_ops do not"
+                        "contains piecewise cudagraph. Setting cudagraph_"
+                        "mode to NONE. Hint: If you are using attention "
+                        "backends that support cudagraph, consider manually "
+                        "setting cudagraph_mode to FULL or FULL_DECODE_ONLY "
+                        "to enable full cudagraphs."
+                    )
+                    self.cudagraph_mode = CUDAGraphMode.NONE
+                elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
+                    logger.warning_once(
+                        "Piecewise compilation with empty splitting_ops do "
+                        "not contains piecewise cudagraph. Setting "
+                        "cudagraph_mode to FULL."
+                    )
+                    self.cudagraph_mode = CUDAGraphMode.FULL
+                self.splitting_ops = []
+
+        # split MoE ops for cudagraph
+        moe_ops = [
+            "vllm::moe_forward",
+            "vllm::moe_forward_shared",
+        ]
+        backend = all2all_backend or envs.VLLM_ALL2ALL_BACKEND
+        dp_size = data_parallel_size if data_parallel_size is not None else 1
+        need_moe_splitting = (
+            backend == "deepep_high_throughput"
+            and dp_size > 1
+            # pure attn-fusion without inductor partition deliberately disables
+            # piecewise graphs and MoE splitting.
+            and not (
+                self.pass_config.fuse_attn_quant
+                and not self.use_inductor_graph_partition
+            )
+        )
+
+        if need_moe_splitting and self.cudagraph_mode != CUDAGraphMode.NONE:
+            # if we just initialized default splitting_ops for this config,
+            # automatically append the MoE ops
+            if added_default_splitting_ops:
+                for op in moe_ops:
+                    if op not in self.splitting_ops:
+                        self.splitting_ops.append(op)
+
+            # make sure MoE ops are split out
+            if not any(op in self.splitting_ops for op in moe_ops):
                 self.cudagraph_mode = CUDAGraphMode.NONE
-            elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
                 logger.warning_once(
-                    "Piecewise compilation with empty splitting_ops do not "
-                    "contains piecewise cudagraph. Setting cudagraph_mode "
-                    "to FULL."
+                    "DeepEP high throughput backend with data_parallel_size > 1 "
+                    "requires splitting MoE ops from cudagraphs. Please ensure "
+                    "'vllm::moe_forward' or 'vllm::moe_forward_shared' are "
+                    "present in CompilationConfig.splitting_ops."
                 )
-                self.cudagraph_mode = CUDAGraphMode.FULL
-            self.splitting_ops = []
-
-    def set_splitting_ops_for_inductor_graph_partition(self):
-        assert self.use_inductor_graph_partition
-        if self.splitting_ops is None:
-            self.splitting_ops = list(self._attention_ops)
+            elif self.cudagraph_mode.has_full_cudagraphs():
+                # fall back to piecewise when MoE splitting is required.
+                self.cudagraph_mode = CUDAGraphMode.PIECEWISE
 
     def set_splitting_ops_for_attn_fusion(self):
         assert self.pass_config.fuse_attn_quant
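
Reviewer note: the core of this hunk is the need_moe_splitting predicate. The standalone sketch below restates its decision table with local names; the helper function and its parameters are local to this example, not vLLM API.

    # Standalone restatement of the need_moe_splitting predicate above.
    def need_moe_splitting(
        backend: str,
        dp_size: int,
        fuse_attn_quant: bool,
        use_inductor_partition: bool,
    ) -> bool:
        # Pure attn-fusion without inductor partition deliberately disables
        # piecewise graphs, so MoE splitting must stay off there as well.
        attn_fusion_only = fuse_attn_quant and not use_inductor_partition
        return (
            backend == "deepep_high_throughput"
            and dp_size > 1
            and not attn_fusion_only
        )

    assert need_moe_splitting("deepep_high_throughput", 8, False, False)
    assert need_moe_splitting("deepep_high_throughput", 8, True, True)
    assert not need_moe_splitting("deepep_high_throughput", 8, True, False)
    assert not need_moe_splitting("deepep_high_throughput", 1, False, False)
    assert not need_moe_splitting("pplx", 8, False, False)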
@@ -813,7 +813,10 @@ class VllmConfig:
         ), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
 
         # Do this after all the updates to compilation_config.mode
-        self.compilation_config.set_splitting_ops_for_v1()
+        self.compilation_config.set_splitting_ops_for_v1(
+            all2all_backend=self.parallel_config.all2all_backend,
+            data_parallel_size=self.parallel_config.data_parallel_size,
+        )
 
         if self.compilation_config.pass_config.enable_sp:
             # With pipeline parallelism or dynamo partitioning,
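
Reviewer note: with the new keyword arguments, the backend and dp size now flow from parallel_config instead of only the VLLM_ALL2ALL_BACKEND environment variable. The sketch below drives the method directly; it assumes a bare CompilationConfig is sufficiently initialized for this call, which VllmConfig guarantees at the call site above but direct callers should verify.

    # Hedged sketch of calling the method directly with explicit values.
    from vllm.config import CompilationConfig, CUDAGraphMode
    from vllm.config.compilation import CompilationMode

    cc = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        cudagraph_mode=CUDAGraphMode.PIECEWISE,
    )
    cc.set_splitting_ops_for_v1(
        all2all_backend="deepep_high_throughput",
        data_parallel_size=8,
    )
    assert cc.splitting_ops is not None
    assert "vllm::moe_forward" in cc.splitting_ops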
@@ -232,44 +232,6 @@ class CudaPlatformBase(Platform):
             logger.info(
                 "Forcing kv cache block size to 64 for FlashMLASparse backend."
             )
-        # lazy import to avoid circular import
-        from vllm.config import CUDAGraphMode
-
-        compilation_config = vllm_config.compilation_config
-        if compilation_config.cudagraph_mode.has_full_cudagraphs():
-            # decode context parallel does not support full cudagraphs
-            if parallel_config.decode_context_parallel_size > 1:
-                logger.warning_once(
-                    "Decode context parallel (DCP) is enabled, which is "
-                    "incompatible with full CUDA graphs. "
-                    "Overriding cudagraph_mode to PIECEWISE."
-                )
-                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-            # prefill context parallel do not support full cudagraphs
-            elif parallel_config.prefill_context_parallel_size > 1:
-                logger.warning_once(
-                    "Prefill context parallel (PCP) is enabled, which is "
-                    "incompatible with full CUDA graphs. "
-                    "Overriding cudagraph_mode to PIECEWISE."
-                )
-                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-        if (
-            parallel_config.all2all_backend == "deepep_high_throughput"
-            and parallel_config.data_parallel_size > 1
-            and compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-        ):
-            # TODO: Piecewise Cuda graph might be enabled
-            # if torch compile cache key issue fixed
-            # See https://github.com/vllm-project/vllm/pull/25093
-            logger.info(
-                "WideEP: Disabling CUDA Graphs since DeepEP high-throughput "
-                "kernels are optimized for prefill and are incompatible with "
-                "CUDA Graphs. "
-                "In order to use CUDA Graphs for decode-optimized workloads, "
-                "use --all2all-backend with another option, such as "
-                "deepep_low_latency, pplx, or allgather_reducescatter."
-            )
-            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
     @classmethod
     def get_current_memory_usage(
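
Reviewer note: the hunk above deletes the platform-level override that forced cudagraph_mode to NONE for DeepEP HT with dp>1, along with the DCP/PCP full-cudagraph overrides shown. The decision now lives in CompilationConfig, which keeps piecewise graphs and splits only the MoE ops out of them. A hedged end-to-end check of that behavior, inferred from this diff:

    from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig
    from vllm.config.compilation import CompilationMode

    config = VllmConfig(
        parallel_config=ParallelConfig(
            all2all_backend="deepep_high_throughput",
            data_parallel_size=8,
        ),
        compilation_config=CompilationConfig(mode=CompilationMode.VLLM_COMPILE),
    )
    # Previously the platform hook would have forced CUDAGraphMode.NONE here.
    assert config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE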