[Bugfix] Disable moe inplace for torch >= 2.9 (#26497)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm 2025-10-09 14:07:38 -04:00 committed by GitHub
parent 4069db3f2e
commit a462331e36
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 22 additions and 6 deletions

View File

@@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_inplace
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new,
marlin_moe_intermediate_size,
@@ -235,7 +235,11 @@ def fused_marlin_moe(
).view(-1, topk, K)
if output is None:
output = hidden_states if inplace else torch.empty_like(hidden_states)
if inplace and not disable_inplace():
output = hidden_states
else:
output = torch.empty_like(hidden_states)
return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)

View File

@@ -39,6 +39,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
activation_without_mul,
disable_inplace,
moe_kernel_quantize_input,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
@@ -1516,7 +1517,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
if inplace:
if inplace and not disable_inplace():
return torch_vllm_inplace_fused_experts
return torch_vllm_outplace_fused_experts
@@ -1766,7 +1767,10 @@ def fused_experts_impl(
else:
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)
if inplace and not disable_inplace():
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
if ocp_mx_scheme is not None:
# TODO: On platforms for which `current_platform.supports_mx()` is True

View File

@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
count_expert_num_tokens,
disable_inplace,
)
from vllm.utils import cdiv
from vllm.v1.worker.ubatching import (
@@ -1139,7 +1140,7 @@ class FusedMoEModularKernel(torch.nn.Module):
- torch.Tensor: The output tensor after applying the MoE layer.
"""
if inplace and self.shared_experts is None:
if inplace and self.shared_experts is None and not disable_inplace():
output = hidden_states
else:
output = torch.zeros_like(hidden_states)

View File

@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_e4m3_quantize,
)
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv
from vllm.utils import cdiv, is_torch_equal_or_newer
from vllm.utils.flashinfer import flashinfer_fp4_quantize
@@ -321,3 +321,10 @@ def _validate_scale_shape(
def activation_without_mul(activation: str) -> str:
    """Return the name of the given activation with a ``_no_mul`` suffix.

    Used to select the variant of an activation kernel that skips the
    elementwise multiply step.
    """
    return f"{activation}_no_mul"
# Torch custom ops can't deal with outputs aliasing inputs so we need to
# disable inplace for torch >= 2.9.
# See https://github.com/vllm-project/vllm/issues/26378
def disable_inplace() -> bool:
    """Return True when in-place MoE output buffers must not be used.

    True iff the installed torch version is >= 2.9, where custom-op
    outputs may not alias their inputs.
    """
    return is_torch_equal_or_newer("2.9")