[PCP&DCP] move CUDAGraph check for PCP&DCP to the check func of platforms (#29952)

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Qiu 2025-12-05 10:40:51 +08:00 committed by GitHub
parent befb59e5b1
commit 0098a6e3da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 43 additions and 22 deletions

View File

@ -671,36 +671,22 @@ class VllmConfig:
if current_platform.support_static_graph_mode():
# if cudagraph_mode has full cudagraphs, we need to check support
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
# decode context parallel does not support full cudagraphs
if self.parallel_config.decode_context_parallel_size > 1:
if (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
and self.model_config is not None
):
if self.model_config.pooler_config is not None:
logger.warning_once(
"Decode context parallel (DCP) is enabled, which is "
"incompatible with full CUDA graphs. "
"Pooling models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# prefill context parallel do not support full cudagraphs
elif self.parallel_config.prefill_context_parallel_size > 1:
elif self.model_config.is_encoder_decoder:
logger.warning_once(
"Prefill context parallel (PCP) is enabled, which is "
"incompatible with full CUDA graphs. "
"Encoder-decoder models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
elif self.model_config is not None:
if self.model_config.pooler_config is not None:
logger.warning_once(
"Pooling models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
elif self.model_config.is_encoder_decoder:
logger.warning_once(
"Encoder-decoder models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# disable cudagraph when enforce eager execution
if self.model_config is not None and self.model_config.enforce_eager:

View File

@ -233,6 +233,23 @@ class CudaPlatformBase(Platform):
from vllm.config import CUDAGraphMode
compilation_config = vllm_config.compilation_config
if compilation_config.cudagraph_mode.has_full_cudagraphs():
# decode context parallel does not support full cudagraphs
if parallel_config.decode_context_parallel_size > 1:
logger.warning_once(
"Decode context parallel (DCP) is enabled, which is "
"incompatible with full CUDA graphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# prefill context parallel do not support full cudagraphs
elif parallel_config.prefill_context_parallel_size > 1:
logger.warning_once(
"Prefill context parallel (PCP) is enabled, which is "
"incompatible with full CUDA graphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
if (
parallel_config.all2all_backend == "deepep_high_throughput"
and parallel_config.data_parallel_size > 1

View File

@ -381,6 +381,24 @@ class RocmPlatform(Platform):
parallel_config = vllm_config.parallel_config
is_eager_execution = compilation_config == CUDAGraphMode.NONE
if compilation_config.cudagraph_mode.has_full_cudagraphs():
# decode context parallel does not support full cudagraphs
if parallel_config.decode_context_parallel_size > 1:
logger.warning_once(
"Decode context parallel (DCP) is enabled, which is "
"incompatible with full CUDA graphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# prefill context parallel do not support full cudagraphs
elif parallel_config.prefill_context_parallel_size > 1:
logger.warning_once(
"Prefill context parallel (PCP) is enabled, which is "
"incompatible with full CUDA graphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
if cache_config and cache_config.block_size is None: