From a462331e367191d61414878072dcee8f23bb9214 Mon Sep 17 00:00:00 2001 From: bnellnm <49004751+bnellnm@users.noreply.github.com> Date: Thu, 9 Oct 2025 14:07:38 -0400 Subject: [PATCH] [Bugfix] Disable moe inplace for torch >= 2.9 (#26497) Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_marlin_moe.py | 8 ++++++-- vllm/model_executor/layers/fused_moe/fused_moe.py | 8 ++++++-- vllm/model_executor/layers/fused_moe/modular_kernel.py | 3 ++- vllm/model_executor/layers/fused_moe/utils.py | 9 ++++++++- 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index b0cc83fd2e450..6412c3eaa1932 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) -from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_inplace from vllm.model_executor.layers.quantization.utils.marlin_utils import ( marlin_make_workspace_new, marlin_moe_intermediate_size, @@ -235,7 +235,11 @@ def fused_marlin_moe( ).view(-1, topk, K) if output is None: - output = hidden_states if inplace else torch.empty_like(hidden_states) + if inplace and not disable_inplace(): + output = hidden_states + else: + output = torch.empty_like(hidden_states) + return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index da7c4a3c55893..eda825ffcae1e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -39,6 +39,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, activation_without_mul, + disable_inplace, moe_kernel_quantize_input, ) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 @@ -1516,7 +1517,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor: def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: - if inplace: + if inplace and not disable_inplace(): return torch_vllm_inplace_fused_experts return torch_vllm_outplace_fused_experts @@ -1766,7 +1767,10 @@ def fused_experts_impl( else: raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") - out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states) + if inplace and not disable_inplace(): + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) if ocp_mx_scheme is not None: # TODO: On platforms for which `current_platform.supports_mx()` is True diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 62162b6cbae10..19e71f917eeed 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, count_expert_num_tokens, + disable_inplace, ) from vllm.utils import cdiv from vllm.v1.worker.ubatching import ( @@ -1139,7 +1140,7 @@ class FusedMoEModularKernel(torch.nn.Module): - torch.Tensor: The output tensor after applying the MoE layer. """ - if inplace and self.shared_experts is None: + if inplace and self.shared_experts is None and not disable_inplace(): output = hidden_states else: output = torch.zeros_like(hidden_states) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index dddf788b62e27..bd68d2ec884de 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( mxfp8_e4m3_quantize, ) from vllm.triton_utils import tl, triton -from vllm.utils import cdiv +from vllm.utils import cdiv, is_torch_equal_or_newer from vllm.utils.flashinfer import flashinfer_fp4_quantize @@ -321,3 +321,10 @@ def _validate_scale_shape( def activation_without_mul(activation: str) -> str: return activation + "_no_mul" + + +# Torch custom ops can't deal with outputs aliasing inputs so we need to +# disable inplace for torch >= 2.9. +# See https://github.com/vllm-project/vllm/issues/26378 +def disable_inplace() -> bool: + return is_torch_equal_or_newer("2.9")