[misc] remove python function call for custom activation op (#11885)

Co-authored-by: youkaichao <youkaichao@gmail.com>
Authored by cennn on 2025-01-10 17:18:25 +08:00, committed by GitHub
parent d53575a5f0
commit d907be7dc7
2 changed files with 46 additions and 60 deletions
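
For context: the functions removed from vllm._custom_ops were one-line Python pass-throughs to the registered custom ops. After this change, each activation layer resolves its kernel once in __init__ (torch.ops._C.* on CUDA-alike/CPU, ipex_ops.* on XPU) and calls it as self.op(...) in the forward path, skipping one Python call per forward. A minimal sketch of the new pattern, assuming the vLLM _C extension is already loaded; the class name here is illustrative, not vLLM's:

import torch

class FastGELUSketch(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Bind the registered custom op once, instead of routing every call
        # through a Python wrapper function in vllm._custom_ops.
        self.op = torch.ops._C.gelu_fast

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)  # the op writes into a preallocated output
        self.op(out, x)
        return out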

vllm/_custom_ops.py

@@ -34,33 +34,6 @@ else:
     from torch.library import impl_abstract as register_fake
 
 
-# activation ops
-def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.gelu_and_mul(out, x)
-
-
-def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.gelu_tanh_and_mul(out, x)
-
-
-def fatrelu_and_mul(out: torch.Tensor,
-                    x: torch.Tensor,
-                    threshold: float = 0.0) -> None:
-    torch.ops._C.fatrelu_and_mul(out, x, threshold)
-
-
-def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.gelu_fast(out, x)
-
-
-def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.gelu_new(out, x)
-
-
-def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.gelu_quick(out, x)
-
-
 # page attention ops
 def paged_attention_v1(
     out: torch.Tensor,
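
Since the deleted wrappers only forwarded their arguments, callers that imported them from vllm._custom_ops can invoke the registered ops directly. A hedged example for gelu_and_mul, assuming a CUDA device and the _C extension loaded; the shapes are illustrative:

import torch

x = torch.randn(8, 256, dtype=torch.float16, device="cuda")
out = torch.empty(8, 128, dtype=torch.float16, device="cuda")

# Previously: from vllm import _custom_ops as ops; ops.gelu_and_mul(out, x)
torch.ops._C.gelu_and_mul(out, x)  # out = GELU(x[..., :128]) * x[..., 128:]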

vllm/model_executor/layers/activation.py

@@ -30,6 +30,8 @@ class FatreluAndMul(CustomOp):
     def __init__(self, threshold: float = 0.):
         super().__init__()
         self.threshold = threshold
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.fatrelu_and_mul
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -39,12 +41,10 @@ class FatreluAndMul(CustomOp):
         return x1 * x2
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        ops.fatrelu_and_mul(out, x, self.threshold)
+        self.op(out, x, self.threshold)
         return out
@@ -103,6 +103,17 @@ class GeluAndMul(CustomOp):
         self.approximate = approximate
         if approximate not in ("none", "tanh"):
             raise ValueError(f"Unknown approximate mode: {approximate}")
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            if approximate == "none":
+                self.op = torch.ops._C.gelu_and_mul
+            elif approximate == "tanh":
+                self.op = torch.ops._C.gelu_tanh_and_mul
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            if approximate == "none":
+                self.op = ipex_ops.gelu_and_mul
+            else:
+                self.op = ipex_ops.gelu_tanh_and_mul
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
@@ -110,27 +121,17 @@ class GeluAndMul(CustomOp):
         return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        if self.approximate == "none":
-            ops.gelu_and_mul(out, x)
-        elif self.approximate == "tanh":
-            ops.gelu_tanh_and_mul(out, x)
+        self.op(out, x)
         return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        if self.approximate == "none":
-            ops.gelu_and_mul(out, x)
-        elif self.approximate == "tanh":
-            ops.gelu_tanh_and_mul(out, x)
+        self.op(out, x)
         return out
 
     def extra_repr(self) -> str:
@@ -140,6 +141,14 @@ class GeluAndMul(CustomOp):
 @CustomOp.register("gelu_new")
 class NewGELU(CustomOp):
 
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.gelu_new
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.gelu_new
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         c = math.sqrt(2.0 / math.pi)
@@ -147,58 +156,62 @@ class NewGELU(CustomOp):
                                            (x + 0.044715 * torch.pow(x, 3.0))))
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_new(out, x)
+        self.op(out, x)
         return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
-        return ops.gelu_new(x)
+        return self.op(x)
 
 
 @CustomOp.register("gelu_fast")
 class FastGELU(CustomOp):
 
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.gelu_fast
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.gelu_fast
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                            (1.0 + 0.044715 * x * x)))
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_fast(out, x)
+        self.op(out, x)
         return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
-        return ops.gelu_fast(x)
+        return self.op(x)
 
 
 @CustomOp.register("quick_gelu")
 class QuickGELU(CustomOp):
     # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.gelu_quick
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.gelu_quick
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         return x * torch.sigmoid(1.702 * x)
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_quick(out, x)
+        self.op(out, x)
         return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_quick(out, x)
+        self.op(out, x)
         return out
 
     # TODO implement forward_xpu for QuickGELU
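
Call sites of the layers themselves are unchanged. A short usage sketch, assuming a CUDA-alike platform so that __init__ binds torch.ops._C.gelu_tanh_and_mul and the CustomOp base dispatches to forward_cuda; the shapes and dtype are illustrative:

import torch
from vllm.model_executor.layers.activation import GeluAndMul

act = GeluAndMul(approximate="tanh")  # binds self.op = torch.ops._C.gelu_tanh_and_mul once
x = torch.randn(4, 2048, dtype=torch.bfloat16, device="cuda")
y = act(x)                            # forward_cuda: self.op(out, x)
assert y.shape == (4, 1024)           # the gated activation halves the last dimension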