[optimization] remove python function call for custom op (#11750)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2025-01-08 01:04:28 +08:00 committed by GitHub
parent c0efe92d8b
commit 869579a702
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 15 additions and 13 deletions

View File

@ -35,10 +35,6 @@ else:
# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.silu_and_mul(out, x)
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_and_mul(out, x)

View File

@ -10,6 +10,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import LazyDict
@ -58,27 +59,31 @@ class SiluAndMul(CustomOp):
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def __init__(self):
super().__init__()
if current_platform.is_cuda_alike():
self.op = torch.ops._C.silu_and_mul
elif current_platform.is_xpu():
import intel_extension_for_pytorch as ipex
self.op = ipex.llm.functional.silu_and_mul
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.silu_and_mul(out, x)
self.op(out, x)
return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
from vllm._ipex_ops import ipex_ops as ops
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.silu_and_mul(out, x)
self.op(out, x)
return out

View File

@ -4,7 +4,6 @@ from typing import Optional
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size, try_get_optimal_moe_config)
from vllm.scalar_type import scalar_types
@ -301,7 +300,8 @@ def fused_marlin_moe(
False,
)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, 2 * N))
intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
intermediate_cache2,

View File

@ -753,7 +753,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int8_w8a16=use_int8_w8a16,
block_shape=block_shape)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(intermediate_cache2,
w2,