From d2a30a2d933226d3951ad98cb5de0c74e2e64826 Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Thu, 18 Sep 2025 15:38:37 -0400
Subject: [PATCH] [Bug] Fix torch Compilation Cache Hit Error (#25093)

Signed-off-by: yewentao256
---
 vllm/config/compilation.py | 12 ------------
 vllm/platforms/cuda.py     | 17 ++++++++++-------
 2 files changed, 10 insertions(+), 19 deletions(-)

diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index f8ccc2022261..3618f472e742 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -563,18 +563,6 @@ class CompilationConfig:
                 self.cudagraph_mode = CUDAGraphMode.FULL
             self.splitting_ops = []
 
-        if envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput":
-            # exclude MoE dispatch/combine from capture by ensuring
-            # piecewise splitting includes them, so communication remains
-            # outside CUDA graphs while compute can still be graphed.
-            moe_ops = [
-                "vllm.moe_forward",
-                "vllm.moe_forward_shared",
-            ]
-            for op in moe_ops:
-                if op not in self.splitting_ops:
-                    self.splitting_ops.append(op)
-
     def splitting_ops_contain_attention(self) -> bool:
         return self.splitting_ops is not None and all(
             op in self.splitting_ops for op in self._attention_ops)
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 8e3436a9e73c..87d8f2b7481b 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -191,14 +191,17 @@ class CudaPlatformBase(Platform):
         compilation_config = vllm_config.compilation_config
         if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
                 and parallel_config.data_parallel_size > 1
-                and compilation_config.cudagraph_mode
-                not in [CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE]):
+                and compilation_config.cudagraph_mode != CUDAGraphMode.NONE):
+            # TODO: Piecewise CUDA graphs might be enabled once the torch
+            # compile cache key issue is fixed.
+            # See https://github.com/vllm-project/vllm/pull/25093
             logger.info(
-                "Data Parallel with DeepEP high-throughput: using PIECEWISE "
-                "CUDA graphs and excluding MoE ops from capture. Set "
-                "VLLM_ALL2ALL_BACKEND=deepep_low_latency if you need MoE "
-                "graphs captured as well.")
-            compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+                "Data Parallel: disabling cudagraphs since DP "
+                "with DeepEP high-throughput kernels is not CUDA Graph "
+                "compatible. The DeepEP low-latency kernels are CUDA Graph "
+                "compatible. Set the all_to_all backend to deepep_low_latency "
+                "to use those kernels instead.")
+            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
     @classmethod
     def get_current_memory_usage(cls,
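
Addendum: a minimal sketch of the cache problem this patch sidesteps. If a
torch.compile cache key is derived from the compilation config and
splitting_ops is mutated afterwards based on an environment variable (as the
removed VLLM_ALL2ALL_BACKEND branch did), two runs with different effective
configs can collide on one key and report a false cache hit. make_cache_key
and the stripped-down CompilationConfig below are illustrative assumptions
for demonstration, not vLLM's actual cache-key machinery.

# Illustrative sketch only -- make_cache_key and this CompilationConfig
# are assumptions, not vLLM's real implementation.
import hashlib
from dataclasses import dataclass, field


@dataclass
class CompilationConfig:
    splitting_ops: list = field(default_factory=list)


def make_cache_key(config: CompilationConfig) -> str:
    # Suppose the cache key is derived from the config before any
    # backend-specific mutation of splitting_ops happens.
    return hashlib.sha256(
        repr(sorted(config.splitting_ops)).encode()).hexdigest()


# Run 1: the key is computed, then splitting_ops is mutated afterwards,
# mirroring the removed branch appending "vllm.moe_forward" when
# VLLM_ALL2ALL_BACKEND == "deepep_high_throughput".
cfg1 = CompilationConfig(splitting_ops=["vllm.unified_attention"])
key1 = make_cache_key(cfg1)
cfg1.splitting_ops.append("vllm.moe_forward")

# Run 2: a different backend setting, so no mutation, same initial ops.
cfg2 = CompilationConfig(splitting_ops=["vllm.unified_attention"])
key2 = make_cache_key(cfg2)

# Same key, different effective graph splitting: run 2 "hits" artifacts
# compiled for run 1's configuration.
assert key1 == key2

Dropping the config mutation (and disabling cudagraphs on the cuda.py side
instead) keeps the key consistent with what is actually compiled, which is
presumably why the TODO defers re-enabling piecewise CUDA graphs until the
cache key issue is fixed.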