diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 677fb069bc07..09600e96a1c6 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -546,7 +546,8 @@ class CompilationConfig:
             # full cudagraph outside the fx graph. This reduces some cpu
             # overhead when the runtime batch_size is not cudagraph captured.
             # see https://github.com/vllm-project/vllm/pull/20059 for details.
-            self.splitting_ops = self._attention_ops
+            # make a copy to avoid mutating the class-level list via reference.
+            self.splitting_ops = list(self._attention_ops)
         elif len(self.splitting_ops) == 0:
             logger.warning_once("Using piecewise compilation with empty "
                                 "splitting_ops.")
@@ -561,6 +562,18 @@ class CompilationConfig:
                 self.cudagraph_mode = CUDAGraphMode.FULL
             self.splitting_ops = []
 
+        if envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput":
+            # exclude MoE dispatch/combine from capture by ensuring
+            # piecewise splitting includes them, so communication remains
+            # outside CUDA graphs while compute can still be graphed.
+            moe_ops = [
+                "vllm.moe_forward",
+                "vllm.moe_forward_shared",
+            ]
+            for op in moe_ops:
+                if op not in self.splitting_ops:
+                    self.splitting_ops.append(op)
+
     def splitting_ops_contain_attention(self) -> bool:
         return self.splitting_ops is not None and all(
             op in self.splitting_ops for op in self._attention_ops)
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 1b0a298352cb..fc1a399d6f43 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -183,16 +183,14 @@ class CudaPlatformBase(Platform):
         compilation_config = vllm_config.compilation_config
         if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
                 and parallel_config.data_parallel_size > 1
-                and compilation_config.cudagraph_mode != CUDAGraphMode.NONE):
+                and compilation_config.cudagraph_mode
+                not in [CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE]):
             logger.info(
-                "Data Parallel: disabling cudagraphs since DP "
-                "with DeepEP high-throughput kernels are not CUDA Graph "
-                "compatible. The DeepEP low-latency kernels are CUDA Graph "
-                "compatible. Set the all_to_all backend to deepep_low_latency "
-                "to use those kernels instead.")
-            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
-            if model_config is not None:
-                model_config.enforce_eager = True
+                "Data Parallel with DeepEP high-throughput: using PIECEWISE "
+                "CUDA graphs and excluding MoE ops from capture. Set "
+                "VLLM_ALL2ALL_BACKEND=deepep_low_latency if you need MoE "
+                "graphs captured as well.")
+            compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
 
     @classmethod
     def get_current_memory_usage(cls,
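
Below is a minimal, standalone sketch of the splitting behaviour this change relies on. The MoE op names mirror the diff; the `_attention_ops` values and the environment-variable harness are illustrative assumptions, not vLLM's actual configuration path:

    import os

    # Placeholder attention ops; in vLLM these come from
    # CompilationConfig._attention_ops (the values here are illustrative).
    _attention_ops = [
        "vllm.unified_attention",
        "vllm.unified_attention_with_output",
    ]

    # Copy rather than alias, as the diff does, so the class-level default
    # list is never mutated in place.
    splitting_ops = list(_attention_ops)

    if os.environ.get("VLLM_ALL2ALL_BACKEND") == "deepep_high_throughput":
        # Splitting the fx graph on the MoE custom ops keeps the DeepEP
        # high-throughput dispatch/combine (all-to-all communication) outside
        # CUDA graph capture, while the surrounding compute can still be
        # captured piecewise.
        for op in ("vllm.moe_forward", "vllm.moe_forward_shared"):
            if op not in splitting_ops:
                splitting_ops.append(op)

    print(splitting_ops)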