[Misc] Add CustomOp interface for device portability (#5255)
commit 41ca62cf03 (parent 974fc9b845)
@@ -44,7 +44,7 @@ def test_act_and_mul(
     elif activation == "gelu_tanh":
         layer = GeluAndMul(approximate="tanh")
     out = layer(x)
-    ref_out = layer._forward(x)
+    ref_out = layer.forward_native(x)
     # The SiLU and GELU implementations are equivalent to the native PyTorch
     # implementations, so we can do exact comparison.
     assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)

@@ -72,7 +72,7 @@ def test_activation(
     x = torch.randn(num_tokens, d, dtype=dtype)
     layer = activation()
     out = layer(x)
-    ref_out = layer._forward(x)
+    ref_out = layer.forward_native(x)
     assert torch.allclose(out,
                           ref_out,
                           atol=get_default_atol(out),

@@ -42,7 +42,7 @@ def test_rms_norm(

     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
-    ref_out = layer._forward(x, residual)
+    ref_out = layer.forward_native(x, residual)
     out = layer(x, residual)
     # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
     # numerical errors than other operators because they involve reductions.

@@ -64,7 +64,7 @@ def test_rotary_embedding(

     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
-    ref_query, ref_key = rope._forward(positions, query, key)
+    ref_query, ref_key = rope.forward_native(positions, query, key)
     out_query, out_key = rope.forward(positions, query, key)
     # Compare the results.
     assert torch.allclose(out_query,

@@ -121,7 +121,7 @@ def test_batched_rotary_embedding(

     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
-    ref_query, ref_key = rope._forward(positions, query, key)
+    ref_query, ref_key = rope.forward_native(positions, query, key)
     out_query, out_key = rope.forward(positions,
                                       query,
                                       key,

@@ -195,7 +195,8 @@ def test_batched_rotary_embedding_multi_lora(

     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
-    ref_query, ref_key = rope._forward(positions, query, key, query_offsets)
+    ref_query, ref_key = rope.forward_native(positions, query, key,
+                                             query_offsets)
     out_query, out_key = rope.forward(positions, query, key,
                                       query_offsets.flatten())
     # Compare the results.

vllm/model_executor/custom_op.py (new file)
@@ -0,0 +1,60 @@
+import torch.nn as nn
+
+from vllm.utils import is_cpu, is_hip
+
+
+class CustomOp(nn.Module):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self._forward_method = self.dispatch_forward()
+
+    def forward(self, *args, **kwargs):
+        return self._forward_method(*args, **kwargs)
+
+    def forward_native(self, *args, **kwargs):
+        """PyTorch-native implementation of the forward method.
+
+        This method is optional. If implemented, it can be used with compilers
+        such as torch.compile or PyTorch XLA. Also, it can be used for testing
+        purposes.
+        """
+        raise NotImplementedError
+
+    def forward_cuda(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward_hip(self, *args, **kwargs):
+        # By default, we assume that HIP ops are compatible with CUDA ops.
+        return self.forward_cuda(*args, **kwargs)
+
+    def forward_xpu(self, *args, **kwargs):
+        # By default, we assume that XPU ops are compatible with CUDA ops.
+        # NOTE(woosuk): This is a placeholder for future extensions.
+        return self.forward_cuda(*args, **kwargs)
+
+    def forward_cpu(self, *args, **kwargs):
+        # By default, we assume that CPU ops are compatible with CUDA ops.
+        return self.forward_cuda(*args, **kwargs)
+
+    def forward_tpu(self, *args, **kwargs):
+        # By default, we assume that TPU ops are compatible with the
+        # PyTorch-native implementation.
+        # NOTE(woosuk): This is a placeholder for future extensions.
+        return self.forward_native(*args, **kwargs)
+
+    def forward_gaudi(self, *args, **kwargs):
+        # By default, we assume that Gaudi ops are compatible with the
+        # PyTorch-native implementation.
+        # NOTE(woosuk): This is a placeholder for future extensions.
+        return self.forward_native(*args, **kwargs)
+
+    def dispatch_forward(self):
+        # NOTE(woosuk): Here we assume that vLLM was built for only one
+        # specific backend. Currently, we do not support dynamic dispatching.
+        if is_hip():
+            return self.forward_hip
+        elif is_cpu():
+            return self.forward_cpu
+        else:
+            return self.forward_cuda

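To illustrate how this interface is meant to be used, here is a minimal sketch of a new op built on CustomOp. The op name QuickGELU and its implementations are hypothetical, not part of this commit; the forward_cuda shown simply reuses the native math so the sketch runs on any build, whereas a real op would launch a fused kernel there. Note that the backend method is bound once, in __init__, via dispatch_forward().

    import torch

    from vllm.model_executor.custom_op import CustomOp


    class QuickGELU(CustomOp):
        """Hypothetical example op: x * sigmoid(1.702 * x)."""

        def forward_native(self, x: torch.Tensor) -> torch.Tensor:
            # Pure-PyTorch reference; usable with torch.compile/XLA and in tests.
            return x * torch.sigmoid(1.702 * x)

        def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
            # A real op would call a fused CUDA kernel here; this sketch just
            # falls back to the native math.
            return self.forward_native(x)


    layer = QuickGELU()           # dispatch_forward() picks the backend method once
    y = layer(torch.randn(4, 8))  # __call__ -> forward() -> bound backend method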
@@ -6,14 +6,14 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from vllm import _custom_ops as ops
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
 
 
-class SiluAndMul(nn.Module):
+class SiluAndMul(CustomOp):
     """An activation function for SwiGLU.
 
     The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

@@ -23,12 +23,14 @@ class SiluAndMul(nn.Module):
     return: (num_tokens, d) or (batch_size, seq_len, d)
     """
 
-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)

@@ -36,7 +38,7 @@ class SiluAndMul(nn.Module):
         return out
 
 
-class GeluAndMul(nn.Module):
+class GeluAndMul(CustomOp):
     """An activation function for GeGLU.
 
     The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

@@ -52,12 +54,14 @@ class GeluAndMul(nn.Module):
         if approximate not in ("none", "tanh"):
             raise ValueError(f"Unknown approximate mode: {approximate}")
 
-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)

@@ -71,28 +75,32 @@ class GeluAndMul(nn.Module):
         return f'approximate={repr(self.approximate)}'
 
 
-class NewGELU(nn.Module):
+class NewGELU(CustomOp):
 
-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         c = math.sqrt(2.0 / math.pi)
         return 0.5 * x * (1.0 + torch.tanh(c *
                                            (x + 0.044715 * torch.pow(x, 3.0))))
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
         out = torch.empty_like(x)
         ops.gelu_new(out, x)
         return out
 
 
-class FastGELU(nn.Module):
+class FastGELU(CustomOp):
 
-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                            (1.0 + 0.044715 * x * x)))
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
         out = torch.empty_like(x)
         ops.gelu_fast(out, x)
         return out

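From a caller's perspective nothing changes for the converted activation ops: layer(x) still goes through nn.Module.__call__ and forward(), which now dispatches to forward_cuda/forward_hip/forward_cpu, while forward_native() remains available as a kernel-free reference. A small usage sketch, assuming a CUDA build of vLLM so the _custom_ops kernels behind forward_cuda are available:

    import torch

    from vllm.model_executor.layers.activation import SiluAndMul

    layer = SiluAndMul()
    x = torch.randn(16, 2 * 512, dtype=torch.float16, device="cuda")

    out = layer(x)                 # dispatched backend implementation
    ref = layer.forward_native(x)  # pure-PyTorch reference

    # Per the tests above, the CUDA kernel matches the native math exactly.
    assert torch.allclose(out, ref, atol=0.0, rtol=0.0)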
@@ -4,10 +4,10 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from vllm import _custom_ops as ops
+from vllm.model_executor.custom_op import CustomOp
 
 
-class RMSNorm(nn.Module):
+class RMSNorm(CustomOp):
     """Root mean square normalization.
 
     Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.

@@ -23,7 +23,7 @@ class RMSNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def _forward(
+    def forward_native(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,

@@ -43,11 +43,13 @@ class RMSNorm(nn.Module):
         else:
             return x, residual
 
-    def forward(
+    def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        from vllm import _custom_ops as ops
+
         if residual is not None:
             ops.fused_add_rms_norm(
                 x,

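The same pattern applies to RMSNorm, with one caveat carried over from the tests above: the fused CUDA kernel works in place on its inputs, so the native reference has to be computed first. A sketch, again assuming a CUDA build; the shapes and tolerances are illustrative, not taken from this commit:

    import torch

    from vllm.model_executor.layers.layernorm import RMSNorm

    layer = RMSNorm(hidden_size=512).to(device="cuda", dtype=torch.float16)
    x = torch.randn(16, 512, dtype=torch.float16, device="cuda")
    residual = torch.randn_like(x)

    # Reference first: the dispatched kernel mutates x/residual in place.
    ref_out, ref_res = layer.forward_native(x.clone(), residual.clone())
    out, res = layer(x, residual)

    # Reductions make LayerNorm-style ops less exact than element-wise kernels,
    # so compare with a loose tolerance.
    assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2)
    assert torch.allclose(res, ref_res, atol=1e-2, rtol=1e-2)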
@@ -27,7 +27,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from vllm import _custom_ops as ops
+from vllm.model_executor.custom_op import CustomOp
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:

@@ -43,7 +43,7 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
     return x.flatten(-2)
 
 
-class RotaryEmbedding(nn.Module):
+class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
 
     def __init__(

@@ -93,7 +93,7 @@ class RotaryEmbedding(nn.Module):
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
-    def _forward(
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,

@@ -138,13 +138,15 @@ class RotaryEmbedding(nn.Module):
         key = key.flatten(-2)
         return query, key
 
-    def forward(
+    def forward_cuda(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        from vllm import _custom_ops as ops
+
         self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
                                                    dtype=query.dtype)
         # ops.rotary_embedding()/batched_rotary_embedding()

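Because every converted op (including RotaryEmbedding) now binds its backend implementation once in CustomOp.__init__, it is easy to check which path a given build selected. A small inspection sketch; it peeks at the private _forward_method attribute set in this commit, so treat it as a debugging aid rather than a supported API:

    from vllm.model_executor.layers.activation import SiluAndMul
    from vllm.model_executor.layers.layernorm import RMSNorm

    for layer in (SiluAndMul(), RMSNorm(hidden_size=512)):
        # Prints forward_cuda on a CUDA build, forward_hip on ROCm,
        # forward_cpu on a CPU build, per dispatch_forward() above.
        print(type(layer).__name__, "->", layer._forward_method.__name__)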