Remove redundant mutates_args and dispatch_key for direct_register_custom_op (#25512)

Signed-off-by: mgoin <mgoin64@gmail.com>
commit 7361ab379f
parent 95bc60e4cb
Author: Michael Goin <mgoin64@gmail.com>
Date:   2025-09-23 18:48:40 -04:00 (committed by GitHub)
28 changed files with 9 additions and 66 deletions
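In short: direct_register_custom_op now fills in mutates_args=[] and dispatch_key=current_platform.dispatch_key itself when callers leave them out, so call sites that only ever passed those defaults can drop both arguments. A minimal before/after sketch of a typical call site (the my_op names are illustrative, not taken from this diff):

# Before: the defaults had to be spelled out at every call site.
direct_register_custom_op(
    op_name="my_op",
    op_func=my_op,
    mutates_args=[],
    fake_impl=my_op_fake,
    dispatch_key=current_platform.dispatch_key,
)

# After: the helper resolves both defaults internally.
direct_register_custom_op(
    op_name="my_op",
    op_func=my_op,
    fake_impl=my_op_fake,
)

Call sites that genuinely mutate arguments (e.g. mutates_args=["output"]) or pin a specific dispatch key (e.g. dispatch_key="CUDA") keep passing them explicitly, as the hunks below show.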


@@ -575,9 +575,7 @@ def unified_attention_fake(
 direct_register_custom_op(
     op_name="unified_attention",
     op_func=unified_attention,
-    mutates_args=[],
     fake_impl=unified_attention_fake,
-    dispatch_key=current_platform.dispatch_key,
     tags=tag_cudagraph_unsafe,
 )
@@ -628,6 +626,5 @@ direct_register_custom_op(
     op_func=unified_attention_with_output,
     mutates_args=["output", "output_block_scale"],
     fake_impl=unified_attention_with_output_fake,
-    dispatch_key=current_platform.dispatch_key,
     tags=tag_cudagraph_unsafe,
 )


@@ -547,7 +547,6 @@ if flashinfer_comm is not None:
             "scale_out",
         ],
         fake_impl=call_trtllm_fused_allreduce_norm_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
     flashinfer_trtllm_fused_allreduce_norm = (
         torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default)


@@ -46,7 +46,6 @@ def register_nccl_symmetric_ops(pynccl_comm):
     direct_register_custom_op(
         op_name="all_reduce_symmetric_with_copy",
         op_func=all_reduce_symmetric_with_copy_impl,
-        mutates_args=[],
         fake_impl=all_reduce_symmetric_with_copy_fake,
     )


@@ -149,29 +149,22 @@ def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
 if supports_custom_op():
-    from vllm.platforms import current_platform
     direct_register_custom_op(
         op_name="all_reduce",
         op_func=all_reduce,
-        mutates_args=[],
         fake_impl=all_reduce_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
     direct_register_custom_op(
         op_name="reduce_scatter",
         op_func=reduce_scatter,
-        mutates_args=[],
         fake_impl=reduce_scatter_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
     direct_register_custom_op(
         op_name="all_gather",
         op_func=all_gather,
-        mutates_args=[],
         fake_impl=all_gather_fake,
-        dispatch_key=current_platform.dispatch_key,
     )


@@ -11,7 +11,6 @@ import torch
 from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
-from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
@@ -283,7 +282,6 @@ try:
         op_func=_lora_expand,
         mutates_args=["output_tensor"],
         fake_impl=_lora_expand_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
     lora_expand = torch.ops.vllm.lora_expand


@@ -11,7 +11,6 @@ import torch
 from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
-from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
@@ -237,7 +236,6 @@ try:
         op_func=_lora_shrink,
         mutates_args=["output_tensor"],
         fake_impl=_lora_shrink_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
     lora_shrink = torch.ops.vllm.lora_shrink


@@ -92,7 +92,6 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
 direct_register_custom_op(
     op_name="flashinfer_fused_moe_blockscale_fp8",
     op_func=flashinfer_fused_moe_blockscale_fp8,
-    mutates_args=[],
     fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
     tags=(torch.Tag.needs_fixed_stride_order, ),
 )


@@ -235,6 +235,5 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
 direct_register_custom_op(
     op_name="fused_marlin_moe",
     op_func=fused_marlin_moe,
-    mutates_args=[],
     fake_impl=fused_marlin_moe_fake,
 )


@@ -1256,7 +1256,6 @@ def outplace_fused_experts_fake(
 direct_register_custom_op(
     op_name="outplace_fused_experts",
     op_func=outplace_fused_experts,
-    mutates_args=[],
     fake_impl=outplace_fused_experts_fake,
     tags=(() if is_torch_equal_or_newer("2.7.0") else
           (torch.Tag.needs_fixed_stride_order, )),


@@ -2040,7 +2040,6 @@ direct_register_custom_op(
     op_func=moe_forward,
     mutates_args=["hidden_states"],
     fake_impl=moe_forward_fake,
-    dispatch_key=current_platform.dispatch_key,
     tags=(torch.Tag.needs_fixed_stride_order, ),
 )
@@ -2071,7 +2070,6 @@ direct_register_custom_op(
     op_func=moe_forward_shared,
     mutates_args=["hidden_states"],
     fake_impl=moe_forward_shared_fake,
-    dispatch_key=current_platform.dispatch_key,
     tags=(torch.Tag.needs_fixed_stride_order, ),
 )


@@ -223,17 +223,13 @@ if current_platform.is_rocm():
     direct_register_custom_op(
         op_name="rocm_aiter_asm_moe_tkw1",
         op_func=rocm_aiter_asm_moe_tkw1_impl,
-        mutates_args=[],
         fake_impl=rocm_aiter_asm_moe_tkw1_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
     direct_register_custom_op(
         op_name="rocm_aiter_fused_moe",
         op_func=rocm_aiter_fused_moe_impl,
-        mutates_args=[],
         fake_impl=rocm_aiter_fused_moe_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
     direct_register_custom_op(
@@ -241,7 +237,6 @@ if current_platform.is_rocm():
         op_func=rocm_aiter_topk_softmax_impl,
         mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
         fake_impl=rocm_aiter_topk_softmax_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
     direct_register_custom_op(
@@ -249,7 +244,6 @@ if current_platform.is_rocm():
         op_func=rocm_aiter_biased_grouped_topk_impl,
         mutates_args=["topk_weights", "topk_ids"],
         fake_impl=rocm_aiter_biased_grouped_topk_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
     direct_register_custom_op(
@@ -257,7 +251,6 @@ if current_platform.is_rocm():
         op_func=rocm_aiter_grouped_topk_impl,
         mutates_args=["topk_weights", "topk_ids"],
         fake_impl=rocm_aiter_grouped_topk_fake,
-        dispatch_key=current_platform.dispatch_key,
     )


@@ -103,17 +103,13 @@ if current_platform.is_rocm():
     direct_register_custom_op(
         op_name="rocm_aiter_rms_norm",
         op_func=rocm_aiter_rms_norm_impl,
-        mutates_args=[],
         fake_impl=rocm_aiter_rms_norm_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
     direct_register_custom_op(
         op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
         op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
-        mutates_args=[],
         fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
-        dispatch_key=current_platform.dispatch_key,
     )


@@ -31,7 +31,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
@@ -401,5 +400,4 @@ direct_register_custom_op(
     op_func=linear_attention,
     mutates_args=["output"],
     fake_impl=linear_attention_fake,
-    dispatch_key=current_platform.dispatch_key,
 )


@@ -27,7 +27,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
@@ -464,5 +463,4 @@ direct_register_custom_op(
     op_func=mamba_mixer,
     mutates_args=["output"],
     fake_impl=mamba_mixer_fake,
-    dispatch_key=current_platform.dispatch_key,
 )


@@ -34,7 +34,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import (
     LoaderFunction, composed_weight_loader, sharded_weight_loader)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
@@ -765,5 +764,4 @@ direct_register_custom_op(
     op_func=mamba_mixer2,
     mutates_args=["output"],
     fake_impl=mamba_mixer2_fake,
-    dispatch_key=current_platform.dispatch_key,
 )


@@ -21,7 +21,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
-from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.short_conv_attn import (
     ShortConvAttentionMetadata)
@@ -251,5 +250,4 @@ direct_register_custom_op(
     op_func=short_conv,
     mutates_args=["output"],
     fake_impl=short_conv_fake,
-    dispatch_key=current_platform.dispatch_key,
 )


@@ -4,7 +4,6 @@ import logging
 import torch
-from vllm.platforms import current_platform
 from vllm.triton_utils import triton
 from vllm.utils import direct_register_custom_op
 from vllm.utils.deep_gemm import fp8_gemm_nt
@@ -75,7 +74,5 @@ def w8a8_deepgemm_block_scaled_mm_fake(
 direct_register_custom_op(
     op_name="w8a8_deepgemm_block_scaled_mm",
     op_func=w8a8_deepgemm_block_scaled_mm,
-    mutates_args=[],
     fake_impl=w8a8_deepgemm_block_scaled_mm_fake,
-    dispatch_key=current_platform.dispatch_key,
 )


@@ -161,7 +161,6 @@ try:
     direct_register_custom_op(
         op_name="_fused_mul_mat_gguf",
         op_func=_fused_mul_mat_gguf,
-        mutates_args=[],
         fake_impl=_fused_mul_mat_gguf_fake,
     )
     fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
@@ -273,7 +272,6 @@ try:
     direct_register_custom_op(
         op_name="_fused_moe_gguf",
         op_func=_fused_moe_gguf,
-        mutates_args=[],
         fake_impl=_fused_moe_gguf_fake,
     )
     fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
@@ -319,7 +317,6 @@ try:
     direct_register_custom_op(
         op_name="_apply_gguf_embedding",
         op_func=_apply_gguf_embedding,
-        mutates_args=[],
         fake_impl=_apply_gguf_embedding_fake,
     )
     apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding


@@ -51,9 +51,7 @@ if current_platform.is_rocm():
     direct_register_custom_op(
         op_name="rocm_aiter_gemm_w8a8",
         op_func=rocm_aiter_gemm_w8a8_impl,
-        mutates_args=[],
         fake_impl=rocm_aiter_gemm_w8a8_fake,
-        dispatch_key=current_platform.dispatch_key,
     )


@@ -91,9 +91,7 @@ if current_platform.is_rocm():
     direct_register_custom_op(
         op_name="rocm_aiter_gemm_w8a8_blockscale",
         op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
-        mutates_args=[],
         fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
     if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
             and current_platform.is_fp8_fnuz()):
@@ -135,7 +133,6 @@ def _w8a8_triton_block_scaled_mm_fake(
 direct_register_custom_op(
     "w8a8_triton_block_scaled_mm_func",
     _w8a8_triton_block_scaled_mm_func,
-    mutates_args=[],
     fake_impl=_w8a8_triton_block_scaled_mm_fake,
     dispatch_key="CUDA",
 )


@@ -113,7 +113,6 @@ try:
     direct_register_custom_op(
         op_name="dequant_mxfp4",
         op_func=_dequant_mxfp4,
-        mutates_args=[],
         fake_impl=_dequant_mxfp4_fake,
     )
     dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
@@ -124,7 +123,6 @@ try:
     direct_register_custom_op(
         op_name="quant_dequant_mxfp4",
         op_func=_quant_dequant_mxfp4,
-        mutates_args=[],
         fake_impl=_quant_dequant_mxfp4_fake,
     )
     quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4


@@ -218,9 +218,7 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
 direct_register_custom_op(
     op_name="rocm_per_tensor_w8a8_scaled_mm_impl",
     op_func=rocm_per_tensor_w8a8_scaled_mm_impl,
-    mutates_args=[],
     fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake,
-    dispatch_key=current_platform.dispatch_key,
 )


@@ -147,5 +147,4 @@ direct_register_custom_op(
     op_func=_flashinfer_rotary_embedding,
     mutates_args=["query", "key"],  # These tensors are modified in-place
     fake_impl=_flashinfer_rotary_embedding_fake,
-    dispatch_key=current_platform.dispatch_key,
 )


@@ -136,9 +136,7 @@ def rocm_unquantized_gemm(layer: torch.nn.Module,
 direct_register_custom_op(
     op_name="rocm_unquantized_gemm_impl",
     op_func=rocm_unquantized_gemm_impl,
-    mutates_args=[],
     fake_impl=rocm_unquantized_gemm_impl_fake,
-    dispatch_key=current_platform.dispatch_key,
 )


@@ -56,7 +56,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils import cdiv, direct_register_custom_op
@@ -141,9 +140,7 @@ def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
 direct_register_custom_op(
     op_name="sequence_parallel_chunk",
     op_func=sequence_parallel_chunk,
-    mutates_args=[],
     fake_impl=sequence_parallel_chunk_fake,
-    dispatch_key=current_platform.dispatch_key,
     tags=(torch.Tag.needs_fixed_stride_order, ),
 )


@@ -48,7 +48,6 @@ from vllm.model_executor.models.utils import (
     is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
     make_layers, maybe_prefix)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
@@ -490,7 +489,6 @@ direct_register_custom_op(
     op_func=plamo2_mamba_mixer,
     mutates_args=["output"],
     fake_impl=plamo2_mamba_mixer_fake,
-    dispatch_key=current_platform.dispatch_key,
 )


@@ -1225,7 +1225,6 @@ direct_register_custom_op(
     op_func=gdn_attention,
     mutates_args=["output"],
     fake_impl=gdn_attention_fake,
-    dispatch_key=current_platform.dispatch_key,
 )


@@ -2546,10 +2546,10 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa
 def direct_register_custom_op(
     op_name: str,
     op_func: Callable,
-    mutates_args: list[str],
+    mutates_args: Optional[list[str]] = None,
     fake_impl: Optional[Callable] = None,
     target_lib: Optional[Library] = None,
-    dispatch_key: str = "CUDA",
+    dispatch_key: Optional[str] = None,
     tags: tuple[torch.Tag, ...] = (),
 ):
     """
@@ -2577,6 +2577,13 @@ def direct_register_custom_op(
             "the required dependencies.")
         return
+    if mutates_args is None:
+        mutates_args = []
+    if dispatch_key is None:
+        from vllm.platforms import current_platform
+        dispatch_key = current_platform.dispatch_key
     import torch.library
     if hasattr(torch.library, "infer_schema"):
         schema_str = torch.library.infer_schema(op_func,
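With the defaults resolved inside the helper, a new registration only needs a name, the real implementation, and a fake implementation. A self-contained sketch of the simplified usage (the scale_rows op, its fake, and the example tensors are illustrative and are not part of this commit; the call assumes a CUDA-capable build, since the default dispatch key comes from current_platform):

import torch

from vllm.utils import direct_register_custom_op


def scale_rows(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Real implementation: a pure function that mutates nothing, so the
    # default mutates_args=[] is exactly right.
    return x * factor


def scale_rows_fake(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Fake (meta) implementation used for tracing and torch.compile.
    return torch.empty_like(x)


# Neither mutates_args nor dispatch_key is passed; the helper fills in
# [] and current_platform.dispatch_key (e.g. "CUDA" on NVIDIA GPUs).
direct_register_custom_op(
    op_name="scale_rows",
    op_func=scale_rows,
    fake_impl=scale_rows_fake,
)

# The op is then reachable under the vllm namespace, matching the
# torch.ops.vllm.* references elsewhere in this commit.
out = torch.ops.vllm.scale_rows(torch.randn(4, 8, device="cuda"), 2.0)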