[misc] remove python function call for custom activation op (#11885)
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in: parent d53575a5f0, commit d907be7dc7
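The commit removes the thin Python wrappers in vllm._custom_ops for the activation kernels and instead binds the underlying torch.ops._C (or IPEX) op handle once in each module's __init__, so the hot path calls the registered kernel directly. A minimal sketch of the pattern, assuming vLLM's _C extension is registered; MyActivation and the platform check below are simplified illustrations, not the actual CustomOp machinery:

import torch


class MyActivation(torch.nn.Module):
    """Sketch of the pattern introduced by this commit (illustrative only)."""

    def __init__(self):
        super().__init__()
        # Old pattern: every forward() call went through a Python wrapper,
        # e.g. vllm._custom_ops.gelu_fast(out, x), which then called torch.ops._C.
        # New pattern: resolve the op handle once and store it on the module.
        self.op = None
        if torch.cuda.is_available():  # simplified stand-in for the platform check
            self.op = torch.ops._C.gelu_fast  # assumes vLLM's _C extension is loaded

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.op is None:
            # PyTorch-native fallback, mirroring forward_native()
            return 0.5 * x * (1.0 + torch.tanh(
                x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
        out = torch.empty_like(x)
        self.op(out, x)  # direct call into the registered kernel
        return out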
vllm/_custom_ops.py
@@ -34,33 +34,6 @@ else:
         from torch.library import impl_abstract as register_fake
 
 
-# activation ops
-def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.gelu_and_mul(out, x)
-
-
-def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.gelu_tanh_and_mul(out, x)
-
-
-def fatrelu_and_mul(out: torch.Tensor,
-                    x: torch.Tensor,
-                    threshold: float = 0.0) -> None:
-    torch.ops._C.fatrelu_and_mul(out, x, threshold)
-
-
-def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.gelu_fast(out, x)
-
-
-def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.gelu_new(out, x)
-
-
-def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.gelu_quick(out, x)
-
-
 # page attention ops
 def paged_attention_v1(
     out: torch.Tensor,
vllm/model_executor/layers/activation.py
@@ -30,6 +30,8 @@ class FatreluAndMul(CustomOp):
     def __init__(self, threshold: float = 0.):
         super().__init__()
         self.threshold = threshold
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.fatrelu_and_mul
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -39,12 +41,10 @@ class FatreluAndMul(CustomOp):
         return x1 * x2
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        ops.fatrelu_and_mul(out, x, self.threshold)
+        self.op(out, x, self.threshold)
         return out
 
 
@@ -103,6 +103,17 @@ class GeluAndMul(CustomOp):
         self.approximate = approximate
         if approximate not in ("none", "tanh"):
             raise ValueError(f"Unknown approximate mode: {approximate}")
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            if approximate == "none":
+                self.op = torch.ops._C.gelu_and_mul
+            elif approximate == "tanh":
+                self.op = torch.ops._C.gelu_tanh_and_mul
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            if approximate == "none":
+                self.op = ipex_ops.gelu_and_mul
+            else:
+                self.op = ipex_ops.gelu_tanh_and_mul
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
@@ -110,27 +121,17 @@ class GeluAndMul(CustomOp):
         return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        if self.approximate == "none":
-            ops.gelu_and_mul(out, x)
-        elif self.approximate == "tanh":
-            ops.gelu_tanh_and_mul(out, x)
+        self.op(out, x)
         return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        if self.approximate == "none":
-            ops.gelu_and_mul(out, x)
-        elif self.approximate == "tanh":
-            ops.gelu_tanh_and_mul(out, x)
+        self.op(out, x)
         return out
 
     def extra_repr(self) -> str:
@@ -140,6 +141,14 @@ class GeluAndMul(CustomOp):
 @CustomOp.register("gelu_new")
 class NewGELU(CustomOp):
 
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.gelu_new
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.gelu_new
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         c = math.sqrt(2.0 / math.pi)
@@ -147,58 +156,62 @@ class NewGELU(CustomOp):
                                             (x + 0.044715 * torch.pow(x, 3.0))))
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_new(out, x)
+        self.op(out, x)
         return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
-        return ops.gelu_new(x)
+        return self.op(x)
 
 
 @CustomOp.register("gelu_fast")
 class FastGELU(CustomOp):
 
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.gelu_fast
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.gelu_fast
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                            (1.0 + 0.044715 * x * x)))
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_fast(out, x)
+        self.op(out, x)
         return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
-        return ops.gelu_fast(x)
+        return self.op(x)
 
 
 @CustomOp.register("quick_gelu")
 class QuickGELU(CustomOp):
     # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.gelu_quick
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.gelu_quick
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         return x * torch.sigmoid(1.702 * x)
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_quick(out, x)
+        self.op(out, x)
         return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_quick(out, x)
+        self.op(out, x)
         return out
 
     # TODO implement forward_xpu for QuickGELU
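For reference, a hedged usage sketch of the updated modules (assumes a CUDA build of vLLM so the torch.ops._C kernels are registered; shapes and dtypes are illustrative):

import torch
from vllm.model_executor.layers.activation import GeluAndMul, QuickGELU

x = torch.randn(4, 2 * 128, device="cuda", dtype=torch.float16)
gelu_mul = GeluAndMul(approximate="tanh")  # __init__ now binds torch.ops._C.gelu_tanh_and_mul
y = gelu_mul(x)                            # dispatch -> forward_cuda -> self.op(out, x)
assert y.shape == (4, 128)

quick = QuickGELU()                        # binds torch.ops._C.gelu_quick
z = quick(torch.randn(4, 128, device="cuda", dtype=torch.float16))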