diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 735b0afbaaeb..823bd96db9ac 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -671,36 +671,22 @@ class VllmConfig:
 
         if current_platform.support_static_graph_mode():
             # if cudagraph_mode has full cudagraphs, we need to check support
-            if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
-                # decode context parallel does not support full cudagraphs
-                if self.parallel_config.decode_context_parallel_size > 1:
+            if (
+                self.compilation_config.cudagraph_mode.has_full_cudagraphs()
+                and self.model_config is not None
+            ):
+                if self.model_config.pooler_config is not None:
                     logger.warning_once(
-                        "Decode context parallel (DCP) is enabled, which is "
-                        "incompatible with full CUDA graphs. "
+                        "Pooling models do not support full cudagraphs. "
                         "Overriding cudagraph_mode to PIECEWISE."
                     )
                     self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-                # prefill context parallel do not support full cudagraphs
-                elif self.parallel_config.prefill_context_parallel_size > 1:
+                elif self.model_config.is_encoder_decoder:
                     logger.warning_once(
-                        "Prefill context parallel (PCP) is enabled, which is "
-                        "incompatible with full CUDA graphs. "
+                        "Encoder-decoder models do not support full cudagraphs. "
                         "Overriding cudagraph_mode to PIECEWISE."
                     )
                     self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-                elif self.model_config is not None:
-                    if self.model_config.pooler_config is not None:
-                        logger.warning_once(
-                            "Pooling models do not support full cudagraphs. "
-                            "Overriding cudagraph_mode to PIECEWISE."
-                        )
-                        self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-                    elif self.model_config.is_encoder_decoder:
-                        logger.warning_once(
-                            "Encoder-decoder models do not support full cudagraphs. "
-                            "Overriding cudagraph_mode to PIECEWISE."
-                        )
-                        self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
 
         # disable cudagraph when enforce eager execution
         if self.model_config is not None and self.model_config.enforce_eager:
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 4bf9401b6b05..1467ca71efec 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -233,6 +233,23 @@ class CudaPlatformBase(Platform):
         from vllm.config import CUDAGraphMode
 
         compilation_config = vllm_config.compilation_config
+        if compilation_config.cudagraph_mode.has_full_cudagraphs():
+            # decode context parallel does not support full cudagraphs
+            if parallel_config.decode_context_parallel_size > 1:
+                logger.warning_once(
+                    "Decode context parallel (DCP) is enabled, which is "
+                    "incompatible with full CUDA graphs. "
+                    "Overriding cudagraph_mode to PIECEWISE."
+                )
+                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+            # prefill context parallel do not support full cudagraphs
+            elif parallel_config.prefill_context_parallel_size > 1:
+                logger.warning_once(
+                    "Prefill context parallel (PCP) is enabled, which is "
+                    "incompatible with full CUDA graphs. "
+                    "Overriding cudagraph_mode to PIECEWISE."
+                )
+                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
         if (
             parallel_config.all2all_backend == "deepep_high_throughput"
             and parallel_config.data_parallel_size > 1
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index ccf3446a3a6e..32c7f8e53663 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -381,6 +381,24 @@ class RocmPlatform(Platform):
         parallel_config = vllm_config.parallel_config
         is_eager_execution = compilation_config == CUDAGraphMode.NONE
 
+        if compilation_config.cudagraph_mode.has_full_cudagraphs():
+            # decode context parallel does not support full cudagraphs
+            if parallel_config.decode_context_parallel_size > 1:
+                logger.warning_once(
+                    "Decode context parallel (DCP) is enabled, which is "
+                    "incompatible with full CUDA graphs. "
+                    "Overriding cudagraph_mode to PIECEWISE."
+                )
+                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+            # prefill context parallel do not support full cudagraphs
+            elif parallel_config.prefill_context_parallel_size > 1:
+                logger.warning_once(
+                    "Prefill context parallel (PCP) is enabled, which is "
+                    "incompatible with full CUDA graphs. "
+                    "Overriding cudagraph_mode to PIECEWISE."
+                )
+                compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+
         use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
 
         if cache_config and cache_config.block_size is None:
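
For context, here is a minimal standalone sketch of the downgrade pattern these hunks move from `VllmConfig` into the per-platform `check_and_update_config` hooks. The `CUDAGraphMode` enum, `has_full_cudagraphs`, and the config field names mirror the diff, but the simplified dataclasses below are hypothetical stand-ins rather than the real vLLM classes, and `print` stands in for `logger.warning_once`:

```python
from dataclasses import dataclass
from enum import Enum


class CUDAGraphMode(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2

    def has_full_cudagraphs(self) -> bool:
        # Simplified predicate: only FULL capture conflicts with DCP/PCP.
        return self is CUDAGraphMode.FULL


@dataclass
class ParallelConfig:
    decode_context_parallel_size: int = 1
    prefill_context_parallel_size: int = 1


@dataclass
class CompilationConfig:
    cudagraph_mode: CUDAGraphMode = CUDAGraphMode.FULL


def downgrade_for_context_parallel(
    compilation_config: CompilationConfig, parallel_config: ParallelConfig
) -> None:
    """Platform-side check: fall back to PIECEWISE when DCP or PCP is on."""
    if not compilation_config.cudagraph_mode.has_full_cudagraphs():
        return
    if parallel_config.decode_context_parallel_size > 1:
        print("DCP enabled; overriding cudagraph_mode to PIECEWISE")
        compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
    elif parallel_config.prefill_context_parallel_size > 1:
        print("PCP enabled; overriding cudagraph_mode to PIECEWISE")
        compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE


comp = CompilationConfig()
downgrade_for_context_parallel(comp, ParallelConfig(decode_context_parallel_size=2))
assert comp.cudagraph_mode is CUDAGraphMode.PIECEWISE
```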