From 08405609cc14d81f20e9faf61b2cd87e7909b797 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Thu, 16 Oct 2025 20:08:47 -0700 Subject: [PATCH] disable graph partition in custom op (#26952) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Boyuan Feng Signed-off-by: Boyuan Feng Co-authored-by: Luka Govedič --- vllm/model_executor/layers/fused_moe/fused_moe.py | 7 ++++++- vllm/model_executor/utils.py | 9 +++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 42724f2ff3c0..69e32438e5b2 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -49,6 +49,7 @@ from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6 from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme +from vllm.model_executor.utils import maybe_disable_graph_partition from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer @@ -1145,7 +1146,11 @@ def fused_topk_bias( # This is used by the Deepseek-V2 and Deepseek-V3 model -@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +@torch.compile( + dynamic=True, + backend=current_platform.simple_compile_backend, + options=maybe_disable_graph_partition(current_platform.simple_compile_backend), +) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 38cd230082f8..5ffee6cb8d8b 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -7,6 +7,8 @@ from typing import Any import torch +from vllm.utils import is_torch_equal_or_newer + def set_random_seed(seed: int) -> None: from vllm.platforms import current_platform @@ -83,3 +85,10 @@ def get_moe_expert_mapping( if child_map is not None: return child_map() return [] + + +def maybe_disable_graph_partition(current_backend: str) -> dict[str, bool]: + if current_backend == "inductor" and is_torch_equal_or_newer("2.9.0.dev"): + return {"graph_partition": False} + else: + return {}