Remove redundant mutates_args and dispatch_key for direct_register_custom_op (#25512)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent 95bc60e4cb
commit 7361ab379f
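For illustration, a minimal sketch of the call-site simplification this commit enables. The op below (`toy_relu`) and its impl/fake functions are hypothetical stand-ins, not ops touched by this diff; only `direct_register_custom_op`, `current_platform`, and the new defaults shown in the signature hunk at the bottom are taken from the source.

    import torch

    from vllm.platforms import current_platform
    from vllm.utils import direct_register_custom_op

    def toy_relu_impl(x: torch.Tensor) -> torch.Tensor:
        # Real implementation registered for the platform's dispatch key.
        return torch.relu(x)

    def toy_relu_fake(x: torch.Tensor) -> torch.Tensor:
        # Fake (meta) implementation used for tracing / torch.compile.
        return torch.empty_like(x)

    # Before this commit, call sites spelled out the defaults explicitly:
    direct_register_custom_op(
        op_name="toy_relu",
        op_func=toy_relu_impl,
        mutates_args=[],
        fake_impl=toy_relu_fake,
        dispatch_key=current_platform.dispatch_key,
    )

    # After this commit the same registration can omit both arguments, since
    # mutates_args now defaults to [] and dispatch_key defaults to
    # current_platform.dispatch_key. Registered under a second name here only
    # so both calls can run in one script.
    direct_register_custom_op(
        op_name="toy_relu_v2",
        op_func=toy_relu_impl,
        fake_impl=toy_relu_fake,
    )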
@@ -575,9 +575,7 @@ def unified_attention_fake(
 direct_register_custom_op(
     op_name="unified_attention",
     op_func=unified_attention,
-    mutates_args=[],
     fake_impl=unified_attention_fake,
-    dispatch_key=current_platform.dispatch_key,
     tags=tag_cudagraph_unsafe,
 )
@@ -628,6 +626,5 @@ direct_register_custom_op(
     op_func=unified_attention_with_output,
     mutates_args=["output", "output_block_scale"],
     fake_impl=unified_attention_with_output_fake,
-    dispatch_key=current_platform.dispatch_key,
     tags=tag_cudagraph_unsafe,
 )
@@ -547,7 +547,6 @@ if flashinfer_comm is not None:
             "scale_out",
         ],
         fake_impl=call_trtllm_fused_allreduce_norm_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
     flashinfer_trtllm_fused_allreduce_norm = (
         torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default)
@@ -46,7 +46,6 @@ def register_nccl_symmetric_ops(pynccl_comm):
 direct_register_custom_op(
     op_name="all_reduce_symmetric_with_copy",
     op_func=all_reduce_symmetric_with_copy_impl,
-    mutates_args=[],
     fake_impl=all_reduce_symmetric_with_copy_fake,
 )
@@ -149,29 +149,22 @@ def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,


 if supports_custom_op():
-    from vllm.platforms import current_platform
     direct_register_custom_op(
         op_name="all_reduce",
         op_func=all_reduce,
-        mutates_args=[],
         fake_impl=all_reduce_fake,
-        dispatch_key=current_platform.dispatch_key,
     )

     direct_register_custom_op(
         op_name="reduce_scatter",
         op_func=reduce_scatter,
-        mutates_args=[],
         fake_impl=reduce_scatter_fake,
-        dispatch_key=current_platform.dispatch_key,
     )

     direct_register_custom_op(
         op_name="all_gather",
         op_func=all_gather,
-        mutates_args=[],
         fake_impl=all_gather_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
@@ -11,7 +11,6 @@ import torch

 from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
-from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op

@@ -283,7 +282,6 @@ try:
         op_func=_lora_expand,
         mutates_args=["output_tensor"],
         fake_impl=_lora_expand_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
     lora_expand = torch.ops.vllm.lora_expand
@@ -11,7 +11,6 @@ import torch

 from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
-from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op

@@ -237,7 +236,6 @@ try:
         op_func=_lora_shrink,
         mutates_args=["output_tensor"],
         fake_impl=_lora_shrink_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
     lora_shrink = torch.ops.vllm.lora_shrink
@@ -92,7 +92,6 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
 direct_register_custom_op(
     op_name="flashinfer_fused_moe_blockscale_fp8",
     op_func=flashinfer_fused_moe_blockscale_fp8,
-    mutates_args=[],
     fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
     tags=(torch.Tag.needs_fixed_stride_order, ),
 )
@@ -235,6 +235,5 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
 direct_register_custom_op(
     op_name="fused_marlin_moe",
     op_func=fused_marlin_moe,
-    mutates_args=[],
     fake_impl=fused_marlin_moe_fake,
 )
@@ -1256,7 +1256,6 @@ def outplace_fused_experts_fake(
 direct_register_custom_op(
     op_name="outplace_fused_experts",
     op_func=outplace_fused_experts,
-    mutates_args=[],
     fake_impl=outplace_fused_experts_fake,
     tags=(() if is_torch_equal_or_newer("2.7.0") else
           (torch.Tag.needs_fixed_stride_order, )),
@@ -2040,7 +2040,6 @@ direct_register_custom_op(
     op_func=moe_forward,
     mutates_args=["hidden_states"],
     fake_impl=moe_forward_fake,
-    dispatch_key=current_platform.dispatch_key,
     tags=(torch.Tag.needs_fixed_stride_order, ),
 )
@@ -2071,7 +2070,6 @@ direct_register_custom_op(
     op_func=moe_forward_shared,
     mutates_args=["hidden_states"],
     fake_impl=moe_forward_shared_fake,
-    dispatch_key=current_platform.dispatch_key,
     tags=(torch.Tag.needs_fixed_stride_order, ),
 )
@@ -223,17 +223,13 @@ if current_platform.is_rocm():
     direct_register_custom_op(
         op_name="rocm_aiter_asm_moe_tkw1",
         op_func=rocm_aiter_asm_moe_tkw1_impl,
-        mutates_args=[],
         fake_impl=rocm_aiter_asm_moe_tkw1_fake,
-        dispatch_key=current_platform.dispatch_key,
     )

     direct_register_custom_op(
         op_name="rocm_aiter_fused_moe",
         op_func=rocm_aiter_fused_moe_impl,
-        mutates_args=[],
         fake_impl=rocm_aiter_fused_moe_fake,
-        dispatch_key=current_platform.dispatch_key,
     )

     direct_register_custom_op(
@@ -241,7 +237,6 @@ if current_platform.is_rocm():
         op_func=rocm_aiter_topk_softmax_impl,
         mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
         fake_impl=rocm_aiter_topk_softmax_fake,
-        dispatch_key=current_platform.dispatch_key,
     )

     direct_register_custom_op(
@@ -249,7 +244,6 @@ if current_platform.is_rocm():
         op_func=rocm_aiter_biased_grouped_topk_impl,
         mutates_args=["topk_weights", "topk_ids"],
         fake_impl=rocm_aiter_biased_grouped_topk_fake,
-        dispatch_key=current_platform.dispatch_key,
     )

     direct_register_custom_op(
@@ -257,7 +251,6 @@ if current_platform.is_rocm():
         op_func=rocm_aiter_grouped_topk_impl,
         mutates_args=["topk_weights", "topk_ids"],
         fake_impl=rocm_aiter_grouped_topk_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
@@ -103,17 +103,13 @@ if current_platform.is_rocm():
     direct_register_custom_op(
         op_name="rocm_aiter_rms_norm",
         op_func=rocm_aiter_rms_norm_impl,
-        mutates_args=[],
         fake_impl=rocm_aiter_rms_norm_fake,
-        dispatch_key=current_platform.dispatch_key,
     )

     direct_register_custom_op(
         op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
         op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
-        mutates_args=[],
         fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
@@ -31,7 +31,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata

@@ -401,5 +400,4 @@ direct_register_custom_op(
     op_func=linear_attention,
     mutates_args=["output"],
     fake_impl=linear_attention_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
@@ -27,7 +27,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata

@@ -464,5 +463,4 @@ direct_register_custom_op(
     op_func=mamba_mixer,
     mutates_args=["output"],
     fake_impl=mamba_mixer_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
@@ -34,7 +34,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import (
     LoaderFunction, composed_weight_loader, sharded_weight_loader)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata

@@ -765,5 +764,4 @@ direct_register_custom_op(
     op_func=mamba_mixer2,
     mutates_args=["output"],
     fake_impl=mamba_mixer2_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
@@ -21,7 +21,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
-from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.short_conv_attn import (
     ShortConvAttentionMetadata)
@@ -251,5 +250,4 @@ direct_register_custom_op(
     op_func=short_conv,
     mutates_args=["output"],
     fake_impl=short_conv_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
@@ -4,7 +4,6 @@ import logging

 import torch

-from vllm.platforms import current_platform
 from vllm.triton_utils import triton
 from vllm.utils import direct_register_custom_op
 from vllm.utils.deep_gemm import fp8_gemm_nt
@@ -75,7 +74,5 @@ def w8a8_deepgemm_block_scaled_mm_fake(
 direct_register_custom_op(
     op_name="w8a8_deepgemm_block_scaled_mm",
     op_func=w8a8_deepgemm_block_scaled_mm,
-    mutates_args=[],
     fake_impl=w8a8_deepgemm_block_scaled_mm_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
@@ -161,7 +161,6 @@ try:
     direct_register_custom_op(
         op_name="_fused_mul_mat_gguf",
         op_func=_fused_mul_mat_gguf,
-        mutates_args=[],
         fake_impl=_fused_mul_mat_gguf_fake,
     )
     fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
@@ -273,7 +272,6 @@ try:
     direct_register_custom_op(
         op_name="_fused_moe_gguf",
         op_func=_fused_moe_gguf,
-        mutates_args=[],
         fake_impl=_fused_moe_gguf_fake,
     )
     fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
@@ -319,7 +317,6 @@ try:
     direct_register_custom_op(
         op_name="_apply_gguf_embedding",
         op_func=_apply_gguf_embedding,
-        mutates_args=[],
         fake_impl=_apply_gguf_embedding_fake,
     )
     apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
@@ -51,9 +51,7 @@ if current_platform.is_rocm():
     direct_register_custom_op(
         op_name="rocm_aiter_gemm_w8a8",
         op_func=rocm_aiter_gemm_w8a8_impl,
-        mutates_args=[],
         fake_impl=rocm_aiter_gemm_w8a8_fake,
-        dispatch_key=current_platform.dispatch_key,
     )

@@ -91,9 +91,7 @@ if current_platform.is_rocm():
     direct_register_custom_op(
         op_name="rocm_aiter_gemm_w8a8_blockscale",
         op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
-        mutates_args=[],
         fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
 if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
         and current_platform.is_fp8_fnuz()):
@@ -135,7 +133,6 @@ def _w8a8_triton_block_scaled_mm_fake(
 direct_register_custom_op(
     "w8a8_triton_block_scaled_mm_func",
     _w8a8_triton_block_scaled_mm_func,
-    mutates_args=[],
     fake_impl=_w8a8_triton_block_scaled_mm_fake,
     dispatch_key="CUDA",
 )
@@ -113,7 +113,6 @@ try:
     direct_register_custom_op(
         op_name="dequant_mxfp4",
         op_func=_dequant_mxfp4,
-        mutates_args=[],
         fake_impl=_dequant_mxfp4_fake,
     )
     dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
@@ -124,7 +123,6 @@ try:
     direct_register_custom_op(
         op_name="quant_dequant_mxfp4",
         op_func=_quant_dequant_mxfp4,
-        mutates_args=[],
         fake_impl=_quant_dequant_mxfp4_fake,
     )
     quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
@@ -218,9 +218,7 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
 direct_register_custom_op(
     op_name="rocm_per_tensor_w8a8_scaled_mm_impl",
     op_func=rocm_per_tensor_w8a8_scaled_mm_impl,
-    mutates_args=[],
     fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
@@ -147,5 +147,4 @@ direct_register_custom_op(
     op_func=_flashinfer_rotary_embedding,
     mutates_args=["query", "key"],  # These tensors are modified in-place
     fake_impl=_flashinfer_rotary_embedding_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
@@ -136,9 +136,7 @@ def rocm_unquantized_gemm(layer: torch.nn.Module,
 direct_register_custom_op(
     op_name="rocm_unquantized_gemm_impl",
     op_func=rocm_unquantized_gemm_impl,
-    mutates_args=[],
     fake_impl=rocm_unquantized_gemm_impl_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
@@ -56,7 +56,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils import cdiv, direct_register_custom_op

@@ -141,9 +140,7 @@ def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
 direct_register_custom_op(
     op_name="sequence_parallel_chunk",
     op_func=sequence_parallel_chunk,
-    mutates_args=[],
     fake_impl=sequence_parallel_chunk_fake,
-    dispatch_key=current_platform.dispatch_key,
     tags=(torch.Tag.needs_fixed_stride_order, ),
 )
@@ -48,7 +48,6 @@ from vllm.model_executor.models.utils import (
     is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
     make_layers, maybe_prefix)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
@@ -490,7 +489,6 @@ direct_register_custom_op(
     op_func=plamo2_mamba_mixer,
     mutates_args=["output"],
     fake_impl=plamo2_mamba_mixer_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
@@ -1225,7 +1225,6 @@ direct_register_custom_op(
     op_func=gdn_attention,
     mutates_args=["output"],
     fake_impl=gdn_attention_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
@@ -2546,10 +2546,10 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa
 def direct_register_custom_op(
     op_name: str,
     op_func: Callable,
-    mutates_args: list[str],
+    mutates_args: Optional[list[str]] = None,
     fake_impl: Optional[Callable] = None,
     target_lib: Optional[Library] = None,
-    dispatch_key: str = "CUDA",
+    dispatch_key: Optional[str] = None,
     tags: tuple[torch.Tag, ...] = (),
 ):
     """
@@ -2577,6 +2577,13 @@ def direct_register_custom_op(
             "the required dependencies.")
         return

+    if mutates_args is None:
+        mutates_args = []
+
+    if dispatch_key is None:
+        from vllm.platforms import current_platform
+        dispatch_key = current_platform.dispatch_key
+
     import torch.library
     if hasattr(torch.library, "infer_schema"):
         schema_str = torch.library.infer_schema(op_func,
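As a usage note, a small sketch under the same assumptions as the example above (`toy_relu` remains a hypothetical op name): once registered, the op is looked up and called through the `torch.ops.vllm` namespace, which is the same pattern the diff itself uses for `lora_expand`, `fused_moe_gguf`, and the other ops.

    # Fetch the registered op and call it like a normal function; the fake
    # impl is what meta tensors / torch.compile tracing will see. Assumes a
    # platform whose dispatch key matches the registration (e.g. CUDA here).
    toy_relu = torch.ops.vllm.toy_relu
    y = toy_relu(torch.randn(4, device="cuda"))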