From b2eb2b5ad7090ad3b3e002b200104a82eeb2fa7f Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Fri, 18 Jul 2025 14:10:21 -0400 Subject: [PATCH] [Kernel] Apply torch.Tag.needs_fixed_stride_order only for torch==2.6.0 (#19346) Signed-off-by: rzou --- csrc/torch_bindings.cpp | 12 ++++++++---- vllm/attention/ops/rocm_aiter_mla.py | 8 ++++++-- vllm/model_executor/layers/fused_moe/fused_moe.py | 8 +++++--- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 23e9212a2f1d1..79e2575974b52 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -20,13 +20,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops // - // The default behavior in PyTorch 2.6 is "requires_contiguous", so we need + // The default behavior in PyTorch 2.6 was changed to "requires_contiguous", + // so we need // to override this for many GEMMs with the following tag. Otherwise, // torch.compile will force all input tensors to be contiguous(), which // will break many custom ops that require column-major weight matrices. - // TODO: remove this for PyTorch 2.8, when the default is planned to switch - // to match exact eager-mode strides. - at::Tag stride_tag = at::Tag::needs_fixed_stride_order; + // This was a bug and PyTorch 2.7 has since fixed this. 
+#if TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 6 + #define stride_tag at::Tag::needs_fixed_stride_order +#else + #define stride_tag +#endif ops.def("weak_ref_tensor(Tensor input) -> Tensor"); ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor); diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py index cce6b46394606..d91cda255ff31 100644 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -6,7 +6,7 @@ from typing import Optional import torch from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer def get_aiter_mla_metadata(max_batch_size: int, block_size: int, @@ -93,8 +93,12 @@ def mla_decode_fwd_fake( if current_platform.is_rocm(): + if is_torch_equal_or_newer("2.7.0"): + tags = () + else: + tags = (torch.Tag.needs_fixed_stride_order, ) direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd", op_func=mla_decode_fwd_impl, mutates_args=["o"], fake_impl=mla_decode_fwd_fake, - tags=[torch.Tag.needs_fixed_stride_order]) + tags=tags) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 4593602600739..aec5d7b252e39 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -33,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( dequant_mxfp4) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import direct_register_custom_op +from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled @@ -1056,7 +1056,8 @@ direct_register_custom_op( op_func=inplace_fused_experts, mutates_args=["hidden_states"], 
fake_impl=inplace_fused_experts_fake, - tags=(torch.Tag.needs_fixed_stride_order, ), + tags=(() if is_torch_equal_or_newer("2.7.0") else + (torch.Tag.needs_fixed_stride_order, )), ) @@ -1122,7 +1123,8 @@ direct_register_custom_op( op_func=outplace_fused_experts, mutates_args=[], fake_impl=outplace_fused_experts_fake, - tags=(torch.Tag.needs_fixed_stride_order, ), + tags=(() if is_torch_equal_or_newer("2.7.0") else + (torch.Tag.needs_fixed_stride_order, )), )