mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-14 04:57:03 +08:00
[Bugfix] Disable moe inplace for torch >= 2.9 (#26497)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
4069db3f2e
commit
a462331e36
@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_inplace
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_make_workspace_new,
|
||||
marlin_moe_intermediate_size,
|
||||
@ -235,7 +235,11 @@ def fused_marlin_moe(
|
||||
).view(-1, topk, K)
|
||||
|
||||
if output is None:
|
||||
output = hidden_states if inplace else torch.empty_like(hidden_states)
|
||||
if inplace and not disable_inplace():
|
||||
output = hidden_states
|
||||
else:
|
||||
output = torch.empty_like(hidden_states)
|
||||
|
||||
return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)
|
||||
|
||||
|
||||
|
||||
@ -39,6 +39,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache,
|
||||
activation_without_mul,
|
||||
disable_inplace,
|
||||
moe_kernel_quantize_input,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
|
||||
@ -1516,7 +1517,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
|
||||
|
||||
|
||||
def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
    """Select the fused-experts implementation for the requested mode.

    The scraped diff left both the old conditional (``if inplace:``) and the
    new one in sequence; this is the resolved post-commit form.

    Args:
        inplace: Whether the caller requested writing results back into the
            input hidden states.

    Returns:
        The inplace variant only when inplace was requested AND inplace
        execution is not globally disabled (torch >= 2.9 custom ops can't
        alias outputs to inputs — see ``disable_inplace()``); otherwise the
        out-of-place variant.
    """
    if inplace and not disable_inplace():
        return torch_vllm_inplace_fused_experts
    return torch_vllm_outplace_fused_experts
|
||||
|
||||
@ -1766,7 +1767,10 @@ def fused_experts_impl(
|
||||
else:
|
||||
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
|
||||
|
||||
out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)
|
||||
if inplace and not disable_inplace():
|
||||
out_hidden_states = hidden_states
|
||||
else:
|
||||
out_hidden_states = torch.empty_like(hidden_states)
|
||||
|
||||
if ocp_mx_scheme is not None:
|
||||
# TODO: On platforms for which `current_platform.supports_mx()` is True
|
||||
|
||||
@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache,
|
||||
count_expert_num_tokens,
|
||||
disable_inplace,
|
||||
)
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.worker.ubatching import (
|
||||
@ -1139,7 +1140,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
"""
|
||||
|
||||
if inplace and self.shared_experts is None:
|
||||
if inplace and self.shared_experts is None and not disable_inplace():
|
||||
output = hidden_states
|
||||
else:
|
||||
output = torch.zeros_like(hidden_states)
|
||||
|
||||
@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
|
||||
mxfp8_e4m3_quantize,
|
||||
)
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import cdiv
|
||||
from vllm.utils import cdiv, is_torch_equal_or_newer
|
||||
from vllm.utils.flashinfer import flashinfer_fp4_quantize
|
||||
|
||||
|
||||
@ -321,3 +321,10 @@ def _validate_scale_shape(
|
||||
|
||||
def activation_without_mul(activation: str) -> str:
    """Return the name of the given activation's no-gating-mul variant."""
    suffix = "_no_mul"
    return f"{activation}{suffix}"
|
||||
|
||||
|
||||
# Torch custom ops can't deal with outputs aliasing inputs so we need to
# disable inplace for torch >= 2.9.
# See https://github.com/vllm-project/vllm/issues/26378
def disable_inplace() -> bool:
    """Return True when inplace MoE execution must be disabled.

    True exactly when the installed torch version is >= 2.9, per
    ``is_torch_equal_or_newer("2.9")``.
    """
    return is_torch_equal_or_newer("2.9")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user