Remove redundant mutates_args and dispatch_key for direct_register_custom_op (#25512)

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Michael Goin 2025-09-23 18:48:40 -04:00 committed by yewentao256
parent eb1f43bc82
commit 907bbca7b7
28 changed files with 9 additions and 66 deletions

View File

@ -575,9 +575,7 @@ def unified_attention_fake(
direct_register_custom_op( direct_register_custom_op(
op_name="unified_attention", op_name="unified_attention",
op_func=unified_attention, op_func=unified_attention,
mutates_args=[],
fake_impl=unified_attention_fake, fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
tags=tag_cudagraph_unsafe, tags=tag_cudagraph_unsafe,
) )
@ -628,6 +626,5 @@ direct_register_custom_op(
op_func=unified_attention_with_output, op_func=unified_attention_with_output,
mutates_args=["output", "output_block_scale"], mutates_args=["output", "output_block_scale"],
fake_impl=unified_attention_with_output_fake, fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
tags=tag_cudagraph_unsafe, tags=tag_cudagraph_unsafe,
) )

View File

@ -547,7 +547,6 @@ if flashinfer_comm is not None:
"scale_out", "scale_out",
], ],
fake_impl=call_trtllm_fused_allreduce_norm_fake, fake_impl=call_trtllm_fused_allreduce_norm_fake,
dispatch_key=current_platform.dispatch_key,
) )
flashinfer_trtllm_fused_allreduce_norm = ( flashinfer_trtllm_fused_allreduce_norm = (
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default) torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default)

View File

@ -46,7 +46,6 @@ def register_nccl_symmetric_ops(pynccl_comm):
direct_register_custom_op( direct_register_custom_op(
op_name="all_reduce_symmetric_with_copy", op_name="all_reduce_symmetric_with_copy",
op_func=all_reduce_symmetric_with_copy_impl, op_func=all_reduce_symmetric_with_copy_impl,
mutates_args=[],
fake_impl=all_reduce_symmetric_with_copy_fake, fake_impl=all_reduce_symmetric_with_copy_fake,
) )

View File

@ -149,29 +149,22 @@ def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
if supports_custom_op(): if supports_custom_op():
from vllm.platforms import current_platform
direct_register_custom_op( direct_register_custom_op(
op_name="all_reduce", op_name="all_reduce",
op_func=all_reduce, op_func=all_reduce,
mutates_args=[],
fake_impl=all_reduce_fake, fake_impl=all_reduce_fake,
dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op( direct_register_custom_op(
op_name="reduce_scatter", op_name="reduce_scatter",
op_func=reduce_scatter, op_func=reduce_scatter,
mutates_args=[],
fake_impl=reduce_scatter_fake, fake_impl=reduce_scatter_fake,
dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op( direct_register_custom_op(
op_name="all_gather", op_name="all_gather",
op_func=all_gather, op_func=all_gather,
mutates_args=[],
fake_impl=all_gather_fake, fake_impl=all_gather_fake,
dispatch_key=current_platform.dispatch_key,
) )

View File

@ -11,7 +11,6 @@ import torch
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
@ -283,7 +282,6 @@ try:
op_func=_lora_expand, op_func=_lora_expand,
mutates_args=["output_tensor"], mutates_args=["output_tensor"],
fake_impl=_lora_expand_fake, fake_impl=_lora_expand_fake,
dispatch_key=current_platform.dispatch_key,
) )
lora_expand = torch.ops.vllm.lora_expand lora_expand = torch.ops.vllm.lora_expand

View File

@ -11,7 +11,6 @@ import torch
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
@ -237,7 +236,6 @@ try:
op_func=_lora_shrink, op_func=_lora_shrink,
mutates_args=["output_tensor"], mutates_args=["output_tensor"],
fake_impl=_lora_shrink_fake, fake_impl=_lora_shrink_fake,
dispatch_key=current_platform.dispatch_key,
) )
lora_shrink = torch.ops.vllm.lora_shrink lora_shrink = torch.ops.vllm.lora_shrink

View File

@ -92,7 +92,6 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
direct_register_custom_op( direct_register_custom_op(
op_name="flashinfer_fused_moe_blockscale_fp8", op_name="flashinfer_fused_moe_blockscale_fp8",
op_func=flashinfer_fused_moe_blockscale_fp8, op_func=flashinfer_fused_moe_blockscale_fp8,
mutates_args=[],
fake_impl=flashinfer_fused_moe_blockscale_fp8_fake, fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
tags=(torch.Tag.needs_fixed_stride_order, ), tags=(torch.Tag.needs_fixed_stride_order, ),
) )

View File

@ -235,6 +235,5 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
direct_register_custom_op( direct_register_custom_op(
op_name="fused_marlin_moe", op_name="fused_marlin_moe",
op_func=fused_marlin_moe, op_func=fused_marlin_moe,
mutates_args=[],
fake_impl=fused_marlin_moe_fake, fake_impl=fused_marlin_moe_fake,
) )

View File

@ -1256,7 +1256,6 @@ def outplace_fused_experts_fake(
direct_register_custom_op( direct_register_custom_op(
op_name="outplace_fused_experts", op_name="outplace_fused_experts",
op_func=outplace_fused_experts, op_func=outplace_fused_experts,
mutates_args=[],
fake_impl=outplace_fused_experts_fake, fake_impl=outplace_fused_experts_fake,
tags=(() if is_torch_equal_or_newer("2.7.0") else tags=(() if is_torch_equal_or_newer("2.7.0") else
(torch.Tag.needs_fixed_stride_order, )), (torch.Tag.needs_fixed_stride_order, )),

View File

@ -2040,7 +2040,6 @@ direct_register_custom_op(
op_func=moe_forward, op_func=moe_forward,
mutates_args=["hidden_states"], mutates_args=["hidden_states"],
fake_impl=moe_forward_fake, fake_impl=moe_forward_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ), tags=(torch.Tag.needs_fixed_stride_order, ),
) )
@ -2071,7 +2070,6 @@ direct_register_custom_op(
op_func=moe_forward_shared, op_func=moe_forward_shared,
mutates_args=["hidden_states"], mutates_args=["hidden_states"],
fake_impl=moe_forward_shared_fake, fake_impl=moe_forward_shared_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ), tags=(torch.Tag.needs_fixed_stride_order, ),
) )

View File

@ -223,17 +223,13 @@ if current_platform.is_rocm():
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_asm_moe_tkw1", op_name="rocm_aiter_asm_moe_tkw1",
op_func=rocm_aiter_asm_moe_tkw1_impl, op_func=rocm_aiter_asm_moe_tkw1_impl,
mutates_args=[],
fake_impl=rocm_aiter_asm_moe_tkw1_fake, fake_impl=rocm_aiter_asm_moe_tkw1_fake,
dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_fused_moe", op_name="rocm_aiter_fused_moe",
op_func=rocm_aiter_fused_moe_impl, op_func=rocm_aiter_fused_moe_impl,
mutates_args=[],
fake_impl=rocm_aiter_fused_moe_fake, fake_impl=rocm_aiter_fused_moe_fake,
dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op( direct_register_custom_op(
@ -241,7 +237,6 @@ if current_platform.is_rocm():
op_func=rocm_aiter_topk_softmax_impl, op_func=rocm_aiter_topk_softmax_impl,
mutates_args=["topk_weights", "topk_indices", "token_expert_indices"], mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
fake_impl=rocm_aiter_topk_softmax_fake, fake_impl=rocm_aiter_topk_softmax_fake,
dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op( direct_register_custom_op(
@ -249,7 +244,6 @@ if current_platform.is_rocm():
op_func=rocm_aiter_biased_grouped_topk_impl, op_func=rocm_aiter_biased_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"], mutates_args=["topk_weights", "topk_ids"],
fake_impl=rocm_aiter_biased_grouped_topk_fake, fake_impl=rocm_aiter_biased_grouped_topk_fake,
dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op( direct_register_custom_op(
@ -257,7 +251,6 @@ if current_platform.is_rocm():
op_func=rocm_aiter_grouped_topk_impl, op_func=rocm_aiter_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"], mutates_args=["topk_weights", "topk_ids"],
fake_impl=rocm_aiter_grouped_topk_fake, fake_impl=rocm_aiter_grouped_topk_fake,
dispatch_key=current_platform.dispatch_key,
) )

View File

@ -103,17 +103,13 @@ if current_platform.is_rocm():
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_rms_norm", op_name="rocm_aiter_rms_norm",
op_func=rocm_aiter_rms_norm_impl, op_func=rocm_aiter_rms_norm_impl,
mutates_args=[],
fake_impl=rocm_aiter_rms_norm_fake, fake_impl=rocm_aiter_rms_norm_fake,
dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_rmsnorm2d_fwd_with_add", op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl, op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
mutates_args=[],
fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake, fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
dispatch_key=current_platform.dispatch_key,
) )

View File

@ -31,7 +31,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
@ -401,5 +400,4 @@ direct_register_custom_op(
op_func=linear_attention, op_func=linear_attention,
mutates_args=["output"], mutates_args=["output"],
fake_impl=linear_attention_fake, fake_impl=linear_attention_fake,
dispatch_key=current_platform.dispatch_key,
) )

View File

@ -27,7 +27,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update) selective_scan_fn, selective_state_update)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
@ -464,5 +463,4 @@ direct_register_custom_op(
op_func=mamba_mixer, op_func=mamba_mixer,
mutates_args=["output"], mutates_args=["output"],
fake_impl=mamba_mixer_fake, fake_impl=mamba_mixer_fake,
dispatch_key=current_platform.dispatch_key,
) )

View File

@ -34,7 +34,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction, composed_weight_loader, sharded_weight_loader) LoaderFunction, composed_weight_loader, sharded_weight_loader)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
@ -765,5 +764,4 @@ direct_register_custom_op(
op_func=mamba_mixer2, op_func=mamba_mixer2,
mutates_args=["output"], mutates_args=["output"],
fake_impl=mamba_mixer2_fake, fake_impl=mamba_mixer2_fake,
dispatch_key=current_platform.dispatch_key,
) )

View File

@ -21,7 +21,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update) causal_conv1d_fn, causal_conv1d_update)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.short_conv_attn import ( from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionMetadata) ShortConvAttentionMetadata)
@ -251,5 +250,4 @@ direct_register_custom_op(
op_func=short_conv, op_func=short_conv,
mutates_args=["output"], mutates_args=["output"],
fake_impl=short_conv_fake, fake_impl=short_conv_fake,
dispatch_key=current_platform.dispatch_key,
) )

View File

@ -4,7 +4,6 @@ import logging
import torch import torch
from vllm.platforms import current_platform
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import fp8_gemm_nt from vllm.utils.deep_gemm import fp8_gemm_nt
@ -75,7 +74,5 @@ def w8a8_deepgemm_block_scaled_mm_fake(
direct_register_custom_op( direct_register_custom_op(
op_name="w8a8_deepgemm_block_scaled_mm", op_name="w8a8_deepgemm_block_scaled_mm",
op_func=w8a8_deepgemm_block_scaled_mm, op_func=w8a8_deepgemm_block_scaled_mm,
mutates_args=[],
fake_impl=w8a8_deepgemm_block_scaled_mm_fake, fake_impl=w8a8_deepgemm_block_scaled_mm_fake,
dispatch_key=current_platform.dispatch_key,
) )

View File

@ -161,7 +161,6 @@ try:
direct_register_custom_op( direct_register_custom_op(
op_name="_fused_mul_mat_gguf", op_name="_fused_mul_mat_gguf",
op_func=_fused_mul_mat_gguf, op_func=_fused_mul_mat_gguf,
mutates_args=[],
fake_impl=_fused_mul_mat_gguf_fake, fake_impl=_fused_mul_mat_gguf_fake,
) )
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
@ -273,7 +272,6 @@ try:
direct_register_custom_op( direct_register_custom_op(
op_name="_fused_moe_gguf", op_name="_fused_moe_gguf",
op_func=_fused_moe_gguf, op_func=_fused_moe_gguf,
mutates_args=[],
fake_impl=_fused_moe_gguf_fake, fake_impl=_fused_moe_gguf_fake,
) )
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
@ -319,7 +317,6 @@ try:
direct_register_custom_op( direct_register_custom_op(
op_name="_apply_gguf_embedding", op_name="_apply_gguf_embedding",
op_func=_apply_gguf_embedding, op_func=_apply_gguf_embedding,
mutates_args=[],
fake_impl=_apply_gguf_embedding_fake, fake_impl=_apply_gguf_embedding_fake,
) )
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding

View File

@ -51,9 +51,7 @@ if current_platform.is_rocm():
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8", op_name="rocm_aiter_gemm_w8a8",
op_func=rocm_aiter_gemm_w8a8_impl, op_func=rocm_aiter_gemm_w8a8_impl,
mutates_args=[],
fake_impl=rocm_aiter_gemm_w8a8_fake, fake_impl=rocm_aiter_gemm_w8a8_fake,
dispatch_key=current_platform.dispatch_key,
) )

View File

@ -91,9 +91,7 @@ if current_platform.is_rocm():
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8_blockscale", op_name="rocm_aiter_gemm_w8a8_blockscale",
op_func=rocm_aiter_gemm_w8a8_blockscale_impl, op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
mutates_args=[],
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake, fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
dispatch_key=current_platform.dispatch_key,
) )
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
and current_platform.is_fp8_fnuz()): and current_platform.is_fp8_fnuz()):
@ -135,7 +133,6 @@ def _w8a8_triton_block_scaled_mm_fake(
direct_register_custom_op( direct_register_custom_op(
"w8a8_triton_block_scaled_mm_func", "w8a8_triton_block_scaled_mm_func",
_w8a8_triton_block_scaled_mm_func, _w8a8_triton_block_scaled_mm_func,
mutates_args=[],
fake_impl=_w8a8_triton_block_scaled_mm_fake, fake_impl=_w8a8_triton_block_scaled_mm_fake,
dispatch_key="CUDA", dispatch_key="CUDA",
) )

View File

@ -113,7 +113,6 @@ try:
direct_register_custom_op( direct_register_custom_op(
op_name="dequant_mxfp4", op_name="dequant_mxfp4",
op_func=_dequant_mxfp4, op_func=_dequant_mxfp4,
mutates_args=[],
fake_impl=_dequant_mxfp4_fake, fake_impl=_dequant_mxfp4_fake,
) )
dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4 dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
@ -124,7 +123,6 @@ try:
direct_register_custom_op( direct_register_custom_op(
op_name="quant_dequant_mxfp4", op_name="quant_dequant_mxfp4",
op_func=_quant_dequant_mxfp4, op_func=_quant_dequant_mxfp4,
mutates_args=[],
fake_impl=_quant_dequant_mxfp4_fake, fake_impl=_quant_dequant_mxfp4_fake,
) )
quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4 quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4

View File

@ -218,9 +218,7 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_per_tensor_w8a8_scaled_mm_impl", op_name="rocm_per_tensor_w8a8_scaled_mm_impl",
op_func=rocm_per_tensor_w8a8_scaled_mm_impl, op_func=rocm_per_tensor_w8a8_scaled_mm_impl,
mutates_args=[],
fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake, fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake,
dispatch_key=current_platform.dispatch_key,
) )

View File

@ -147,5 +147,4 @@ direct_register_custom_op(
op_func=_flashinfer_rotary_embedding, op_func=_flashinfer_rotary_embedding,
mutates_args=["query", "key"], # These tensors are modified in-place mutates_args=["query", "key"], # These tensors are modified in-place
fake_impl=_flashinfer_rotary_embedding_fake, fake_impl=_flashinfer_rotary_embedding_fake,
dispatch_key=current_platform.dispatch_key,
) )

View File

@ -136,9 +136,7 @@ def rocm_unquantized_gemm(layer: torch.nn.Module,
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_unquantized_gemm_impl", op_name="rocm_unquantized_gemm_impl",
op_func=rocm_unquantized_gemm_impl, op_func=rocm_unquantized_gemm_impl,
mutates_args=[],
fake_impl=rocm_unquantized_gemm_impl_fake, fake_impl=rocm_unquantized_gemm_impl_fake,
dispatch_key=current_platform.dispatch_key,
) )

View File

@ -56,7 +56,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv, direct_register_custom_op from vllm.utils import cdiv, direct_register_custom_op
@ -141,9 +140,7 @@ def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
direct_register_custom_op( direct_register_custom_op(
op_name="sequence_parallel_chunk", op_name="sequence_parallel_chunk",
op_func=sequence_parallel_chunk, op_func=sequence_parallel_chunk,
mutates_args=[],
fake_impl=sequence_parallel_chunk_fake, fake_impl=sequence_parallel_chunk_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ), tags=(torch.Tag.needs_fixed_stride_order, ),
) )

View File

@ -48,7 +48,6 @@ from vllm.model_executor.models.utils import (
is_pp_missing_parameter, make_empty_intermediate_tensors_factory, is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
make_layers, maybe_prefix) make_layers, maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
@ -490,7 +489,6 @@ direct_register_custom_op(
op_func=plamo2_mamba_mixer, op_func=plamo2_mamba_mixer,
mutates_args=["output"], mutates_args=["output"],
fake_impl=plamo2_mamba_mixer_fake, fake_impl=plamo2_mamba_mixer_fake,
dispatch_key=current_platform.dispatch_key,
) )

View File

@ -1225,7 +1225,6 @@ direct_register_custom_op(
op_func=gdn_attention, op_func=gdn_attention,
mutates_args=["output"], mutates_args=["output"],
fake_impl=gdn_attention_fake, fake_impl=gdn_attention_fake,
dispatch_key=current_platform.dispatch_key,
) )

View File

@ -2546,10 +2546,10 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa
def direct_register_custom_op( def direct_register_custom_op(
op_name: str, op_name: str,
op_func: Callable, op_func: Callable,
mutates_args: list[str], mutates_args: Optional[list[str]] = None,
fake_impl: Optional[Callable] = None, fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None, target_lib: Optional[Library] = None,
dispatch_key: str = "CUDA", dispatch_key: Optional[str] = None,
tags: tuple[torch.Tag, ...] = (), tags: tuple[torch.Tag, ...] = (),
): ):
""" """
@ -2577,6 +2577,13 @@ def direct_register_custom_op(
"the required dependencies.") "the required dependencies.")
return return
if mutates_args is None:
mutates_args = []
if dispatch_key is None:
from vllm.platforms import current_platform
dispatch_key = current_platform.dispatch_key
import torch.library import torch.library
if hasattr(torch.library, "infer_schema"): if hasattr(torch.library, "infer_schema"):
schema_str = torch.library.infer_schema(op_func, schema_str = torch.library.infer_schema(op_func,