[Misc] Add CustomOp interface for device portability (#5255)
commit 41ca62cf03 (parent 974fc9b845)
@@ -44,7 +44,7 @@ def test_act_and_mul(
     elif activation == "gelu_tanh":
         layer = GeluAndMul(approximate="tanh")
     out = layer(x)
-    ref_out = layer._forward(x)
+    ref_out = layer.forward_native(x)
     # The SiLU and GELU implementations are equivalent to the native PyTorch
     # implementations, so we can do exact comparison.
     assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)
@@ -72,7 +72,7 @@ def test_activation(
     x = torch.randn(num_tokens, d, dtype=dtype)
     layer = activation()
     out = layer(x)
-    ref_out = layer._forward(x)
+    ref_out = layer.forward_native(x)
     assert torch.allclose(out,
                           ref_out,
                           atol=get_default_atol(out),
@@ -42,7 +42,7 @@ def test_rms_norm(

     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
-    ref_out = layer._forward(x, residual)
+    ref_out = layer.forward_native(x, residual)
     out = layer(x, residual)
     # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
     # numerical errors than other operators because they involve reductions.
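
The NOTE in this hunk (and in the rotary-embedding hunks below) deserves a concrete illustration: the CUDA kernels write their results into the input tensors in place, so the forward_native reference must run first, otherwise it would be computed from tensors the kernel has already overwritten. A standalone sketch of the ordering issue, using a stand-in in-place op rather than the real kernels:

import torch

def double_native(x: torch.Tensor) -> torch.Tensor:
    return x * 2                   # pure reference, leaves x untouched

def double_inplace(x: torch.Tensor) -> torch.Tensor:
    return x.mul_(2)               # mutates x, like the in-place kernels

x = torch.ones(3)
ref = double_native(x)             # reference first: x is still all ones
out = double_inplace(x)            # kernel second: x (and out) becomes 2.0
assert torch.allclose(out, ref)    # valid comparison
# In the reversed order, double_native would see the mutated x and return
# 4.0, so the test would no longer compare the kernel against a reference.
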
@@ -64,7 +64,7 @@ def test_rotary_embedding(

     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
-    ref_query, ref_key = rope._forward(positions, query, key)
+    ref_query, ref_key = rope.forward_native(positions, query, key)
     out_query, out_key = rope.forward(positions, query, key)
     # Compare the results.
     assert torch.allclose(out_query,
@@ -121,7 +121,7 @@ def test_batched_rotary_embedding(

     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
-    ref_query, ref_key = rope._forward(positions, query, key)
+    ref_query, ref_key = rope.forward_native(positions, query, key)
     out_query, out_key = rope.forward(positions,
                                       query,
                                       key,
@@ -195,7 +195,8 @@ def test_batched_rotary_embedding_multi_lora(

     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
-    ref_query, ref_key = rope._forward(positions, query, key, query_offsets)
+    ref_query, ref_key = rope.forward_native(positions, query, key,
+                                             query_offsets)
     out_query, out_key = rope.forward(positions, query, key,
                                       query_offsets.flatten())
     # Compare the results.
vllm/model_executor/custom_op.py (new file, +60 lines)
@@ -0,0 +1,60 @@
+import torch.nn as nn
+
+from vllm.utils import is_cpu, is_hip
+
+
+class CustomOp(nn.Module):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self._forward_method = self.dispatch_forward()
+
+    def forward(self, *args, **kwargs):
+        return self._forward_method(*args, **kwargs)
+
+    def forward_native(self, *args, **kwargs):
+        """PyTorch-native implementation of the forward method.
+
+        This method is optional. If implemented, it can be used with compilers
+        such as torch.compile or PyTorch XLA. Also, it can be used for testing
+        purposes.
+        """
+        raise NotImplementedError
+
+    def forward_cuda(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward_hip(self, *args, **kwargs):
+        # By default, we assume that HIP ops are compatible with CUDA ops.
+        return self.forward_cuda(*args, **kwargs)
+
+    def forward_xpu(self, *args, **kwargs):
+        # By default, we assume that XPU ops are compatible with CUDA ops.
+        # NOTE(woosuk): This is a placeholder for future extensions.
+        return self.forward_cuda(*args, **kwargs)
+
+    def forward_cpu(self, *args, **kwargs):
+        # By default, we assume that CPU ops are compatible with CUDA ops.
+        return self.forward_cuda(*args, **kwargs)
+
+    def forward_tpu(self, *args, **kwargs):
+        # By default, we assume that TPU ops are compatible with the
+        # PyTorch-native implementation.
+        # NOTE(woosuk): This is a placeholder for future extensions.
+        return self.forward_native(*args, **kwargs)
+
+    def forward_gaudi(self, *args, **kwargs):
+        # By default, we assume that Gaudi ops are compatible with the
+        # PyTorch-native implementation.
+        # NOTE(woosuk): This is a placeholder for future extensions.
+        return self.forward_native(*args, **kwargs)
+
+    def dispatch_forward(self):
+        # NOTE(woosuk): Here we assume that vLLM was built for only one
+        # specific backend. Currently, we do not support dynamic dispatching.
+        if is_hip():
+            return self.forward_hip
+        elif is_cpu():
+            return self.forward_cpu
+        else:
+            return self.forward_cuda
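
For readers unfamiliar with the new interface, here is a minimal sketch (not part of this diff) of how a layer plugs into CustomOp; ScaledSquare is a hypothetical op, and the snippet assumes an installed vLLM build whose dispatch resolves to forward_cuda:

import torch

from vllm.model_executor.custom_op import CustomOp


class ScaledSquare(CustomOp):
    """Hypothetical op computing scale * x * x, for illustration only."""

    def __init__(self, scale: float):
        # CustomOp.__init__ binds self._forward_method once, via
        # dispatch_forward(), so calling the module routes to the backend
        # method chosen at construction time.
        super().__init__()
        self.scale = scale

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch reference, also usable with torch.compile / XLA.
        return self.scale * x * x

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A real op would call a custom kernel here; the sketch simply
        # reuses the native implementation.
        return self.forward_native(x)


layer = ScaledSquare(2.0)
x = torch.randn(4, 8)
out = layer(x)                  # dispatched path (forward_cuda on CUDA builds)
ref = layer.forward_native(x)   # explicit reference, as in the tests above
assert torch.allclose(out, ref)

This is exactly the shape that SiluAndMul, GeluAndMul, NewGELU, FastGELU, RMSNorm, and RotaryEmbedding take in the hunks below.
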
@@ -6,14 +6,14 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from vllm import _custom_ops as ops
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs


-class SiluAndMul(nn.Module):
+class SiluAndMul(CustomOp):
     """An activation function for SwiGLU.

     The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
@@ -23,12 +23,14 @@ class SiluAndMul(nn.Module):
         return: (num_tokens, d) or (batch_size, seq_len, d)
     """

-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -36,7 +38,7 @@ class SiluAndMul(nn.Module):
         return out


-class GeluAndMul(nn.Module):
+class GeluAndMul(CustomOp):
     """An activation function for GeGLU.

     The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
@@ -52,12 +54,14 @@ class GeluAndMul(nn.Module):
         if approximate not in ("none", "tanh"):
             raise ValueError(f"Unknown approximate mode: {approximate}")

-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -71,28 +75,32 @@ class GeluAndMul(nn.Module):
         return f'approximate={repr(self.approximate)}'


-class NewGELU(nn.Module):
+class NewGELU(CustomOp):

-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         c = math.sqrt(2.0 / math.pi)
         return 0.5 * x * (1.0 + torch.tanh(c *
                                            (x + 0.044715 * torch.pow(x, 3.0))))

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
         out = torch.empty_like(x)
         ops.gelu_new(out, x)
         return out


-class FastGELU(nn.Module):
+class FastGELU(CustomOp):

-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                            (1.0 + 0.044715 * x * x)))

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
         out = torch.empty_like(x)
         ops.gelu_fast(out, x)
         return out
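
One detail of the activation hunks above: "from vllm import _custom_ops as ops" moves from module level into each forward_cuda. The diff does not state the motivation, but presumably the deferred import keeps the module importable on builds where the compiled extension is unavailable, since only the CUDA path ever resolves it. A short sketch of the pattern as a standalone function (assuming a build that ships vllm._custom_ops):

import torch

def gelu_fast_cuda(x: torch.Tensor) -> torch.Tensor:
    # Import resolved only when this backend path actually runs.
    from vllm import _custom_ops as ops
    out = torch.empty_like(x)
    ops.gelu_fast(out, x)       # same out-parameter style as FastGELU above
    return out
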
@@ -4,10 +4,10 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn

-from vllm import _custom_ops as ops
+from vllm.model_executor.custom_op import CustomOp


-class RMSNorm(nn.Module):
+class RMSNorm(CustomOp):
     """Root mean square normalization.

     Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
@@ -23,7 +23,7 @@ class RMSNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps

-    def _forward(
+    def forward_native(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
@@ -43,11 +43,13 @@ class RMSNorm(nn.Module):
         else:
             return x, residual

-    def forward(
+    def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        from vllm import _custom_ops as ops
+
         if residual is not None:
             ops.fused_add_rms_norm(
                 x,
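
Because dispatch happens once in CustomOp.__init__, the backend method bound to any of these layers can be inspected directly. A hedged usage sketch, assuming RMSNorm lives in vllm.model_executor.layers.layernorm as in upstream vLLM:

from vllm.model_executor.layers.layernorm import RMSNorm

layer = RMSNorm(hidden_size=16)
# Bound at construction time by dispatch_forward(): forward_cuda on CUDA
# builds, forward_hip on ROCm, forward_cpu on CPU-only builds.
print(layer._forward_method)
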
@@ -27,7 +27,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn

-from vllm import _custom_ops as ops
+from vllm.model_executor.custom_op import CustomOp


 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -43,7 +43,7 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
     return x.flatten(-2)


-class RotaryEmbedding(nn.Module):
+class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""

     def __init__(
@@ -93,7 +93,7 @@ class RotaryEmbedding(nn.Module):
         cache = torch.cat((cos, sin), dim=-1)
         return cache

-    def _forward(
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -138,13 +138,15 @@ class RotaryEmbedding(nn.Module):
         key = key.flatten(-2)
         return query, key

-    def forward(
+    def forward_cuda(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        from vllm import _custom_ops as ops
+
         self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
                                                    dtype=query.dtype)
         # ops.rotary_embedding()/batched_rotary_embedding()