[Misc] Add CustomOp interface for device portability (#5255)

Woosuk Kwon 2024-06-05 09:18:19 -07:00 committed by GitHub
parent 974fc9b845
commit 41ca62cf03
7 changed files with 100 additions and 27 deletions

View File

@@ -44,7 +44,7 @@ def test_act_and_mul(
    elif activation == "gelu_tanh":
        layer = GeluAndMul(approximate="tanh")
    out = layer(x)
-    ref_out = layer._forward(x)
+    ref_out = layer.forward_native(x)
    # The SiLU and GELU implementations are equivalent to the native PyTorch
    # implementations, so we can do exact comparison.
    assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)
@@ -72,7 +72,7 @@ def test_activation(
    x = torch.randn(num_tokens, d, dtype=dtype)
    layer = activation()
    out = layer(x)
-    ref_out = layer._forward(x)
+    ref_out = layer.forward_native(x)
    assert torch.allclose(out,
                          ref_out,
                          atol=get_default_atol(out),

View File

@@ -42,7 +42,7 @@ def test_rms_norm(
    # NOTE(woosuk): The reference implementation should be executed first
    # because the custom kernel is in-place.
-    ref_out = layer._forward(x, residual)
+    ref_out = layer.forward_native(x, residual)
    out = layer(x, residual)
    # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
    # numerical errors than other operators because they involve reductions.

View File

@@ -64,7 +64,7 @@ def test_rotary_embedding(
    # NOTE(woosuk): The reference implementation should be executed first
    # because the custom kernel is in-place.
-    ref_query, ref_key = rope._forward(positions, query, key)
+    ref_query, ref_key = rope.forward_native(positions, query, key)
    out_query, out_key = rope.forward(positions, query, key)
    # Compare the results.
    assert torch.allclose(out_query,
@@ -121,7 +121,7 @@ def test_batched_rotary_embedding(
    # NOTE(woosuk): The reference implementation should be executed first
    # because the custom kernel is in-place.
-    ref_query, ref_key = rope._forward(positions, query, key)
+    ref_query, ref_key = rope.forward_native(positions, query, key)
    out_query, out_key = rope.forward(positions,
                                      query,
                                      key,
@@ -195,7 +195,8 @@ def test_batched_rotary_embedding_multi_lora(
    # NOTE(woosuk): The reference implementation should be executed first
    # because the custom kernel is in-place.
-    ref_query, ref_key = rope._forward(positions, query, key, query_offsets)
+    ref_query, ref_key = rope.forward_native(positions, query, key,
+                                             query_offsets)
    out_query, out_key = rope.forward(positions, query, key,
                                      query_offsets.flatten())
    # Compare the results.

View File

@@ -0,0 +1,60 @@
import torch.nn as nn

from vllm.utils import is_cpu, is_hip


class CustomOp(nn.Module):

    def __init__(self, *args, **kwargs):
        super().__init__()
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """PyTorch-native implementation of the forward method.

        This method is optional. If implemented, it can be used with compilers
        such as torch.compile or PyTorch XLA. Also, it can be used for testing
        purposes.
        """
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        # By default, we assume that HIP ops are compatible with CUDA ops.
        return self.forward_cuda(*args, **kwargs)

    def forward_xpu(self, *args, **kwargs):
        # By default, we assume that XPU ops are compatible with CUDA ops.
        # NOTE(woosuk): This is a placeholder for future extensions.
        return self.forward_cuda(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        # By default, we assume that CPU ops are compatible with CUDA ops.
        return self.forward_cuda(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        # By default, we assume that TPU ops are compatible with the
        # PyTorch-native implementation.
        # NOTE(woosuk): This is a placeholder for future extensions.
        return self.forward_native(*args, **kwargs)

    def forward_gaudi(self, *args, **kwargs):
        # By default, we assume that Gaudi ops are compatible with the
        # PyTorch-native implementation.
        # NOTE(woosuk): This is a placeholder for future extensions.
        return self.forward_native(*args, **kwargs)

    def dispatch_forward(self):
        # NOTE(woosuk): Here we assume that vLLM was built for only one
        # specific backend. Currently, we do not support dynamic dispatching.
        if is_hip():
            return self.forward_hip
        elif is_cpu():
            return self.forward_cpu
        else:
            return self.forward_cuda
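
A minimal sketch of how a layer plugs into this interface; ScaledAdd and its math are hypothetical and used only for illustration, while CustomOp and the method names come from the file above:

import torch

from vllm.model_executor.custom_op import CustomOp


class ScaledAdd(CustomOp):
    """Hypothetical op for illustration: computes scale * (x + y)."""

    def __init__(self, scale: float):
        # CustomOp.__init__ resolves the backend method once, at construction.
        super().__init__()
        self.scale = scale

    def forward_native(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # PyTorch-native reference; usable for tests and torch.compile / XLA.
        return self.scale * (x + y)

    def forward_cuda(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # A real op would call a fused kernel here; the sketch reuses the
        # native math so it runs anywhere.
        return self.scale * (x + y)


layer = ScaledAdd(2.0)
x, y = torch.randn(4, 8), torch.randn(4, 8)
out = layer(x, y)  # __call__ -> CustomOp.forward -> forward_hip/_cpu/_cuda
assert torch.allclose(out, layer.forward_native(x, y))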

View File

@@ -6,14 +6,14 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

-from vllm import _custom_ops as ops
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
+from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs


-class SiluAndMul(nn.Module):
+class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
@@ -23,12 +23,14 @@ class SiluAndMul(nn.Module):
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -36,7 +38,7 @@ class SiluAndMul(nn.Module):
        return out


-class GeluAndMul(nn.Module):
+class GeluAndMul(CustomOp):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
@@ -52,12 +54,14 @@ class GeluAndMul(nn.Module):
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -71,28 +75,32 @@ class GeluAndMul(nn.Module):
        return f'approximate={repr(self.approximate)}'


-class NewGELU(nn.Module):
+class NewGELU(CustomOp):

-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c *
                                           (x + 0.044715 * torch.pow(x, 3.0))))

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
        out = torch.empty_like(x)
        ops.gelu_new(out, x)
        return out


-class FastGELU(nn.Module):
+class FastGELU(CustomOp):

-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                           (1.0 + 0.044715 * x * x)))

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
        out = torch.empty_like(x)
        ops.gelu_fast(out, x)
        return out
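
Call sites for the migrated layers are unchanged: layer(x) still goes through nn.Module.__call__, which now lands in CustomOp.forward and the backend method chosen at construction time, while forward_native stays available as the reference. A small sketch mirroring the updated tests above (assumes a CUDA build of vLLM and a CUDA device; exact equality holds for SiLU per the test comment):

import torch

from vllm.model_executor.layers.activation import SiluAndMul

layer = SiluAndMul()
x = torch.randn(16, 2 * 512, device="cuda")   # last dim is split in half
out = layer(x)                                # dispatched forward_cuda kernel
ref = layer.forward_native(x)                 # silu(x[..., :d]) * x[..., d:]
assert torch.allclose(out, ref, atol=0.0, rtol=0.0)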

View File

@@ -4,10 +4,10 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn

-from vllm import _custom_ops as ops
+from vllm.model_executor.custom_op import CustomOp


-class RMSNorm(nn.Module):
+class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
@@ -23,7 +23,7 @@ class RMSNorm(nn.Module):
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

-    def _forward(
+    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
@@ -43,11 +43,13 @@ class RMSNorm(nn.Module):
        else:
            return x, residual

-    def forward(
+    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        from vllm import _custom_ops as ops
+
        if residual is not None:
            ops.fused_add_rms_norm(
                x,
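
As a plain-PyTorch illustration of what the forward_native reference computes here (x -> w * x / sqrt(E[x^2] + eps), per the docstring above), a hedged sketch follows; it approximates but is not the exact body of RMSNorm.forward_native, which also covers the fused-residual path shown in the hunk:

import torch


def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    # Compute the statistics in float32 for numerical stability, then cast back.
    orig_dtype = x.dtype
    xf = x.to(torch.float32)
    variance = xf.pow(2).mean(dim=-1, keepdim=True)
    xf = xf * torch.rsqrt(variance + eps)
    return weight * xf.to(orig_dtype)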

View File

@@ -27,7 +27,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn

-from vllm import _custom_ops as ops
+from vllm.model_executor.custom_op import CustomOp


def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -43,7 +43,7 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    return x.flatten(-2)


-class RotaryEmbedding(nn.Module):
+class RotaryEmbedding(CustomOp):
    """Original rotary positional embedding."""

    def __init__(
@@ -93,7 +93,7 @@ class RotaryEmbedding(nn.Module):
        cache = torch.cat((cos, sin), dim=-1)
        return cache

-    def _forward(
+    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
@@ -138,13 +138,15 @@ class RotaryEmbedding(nn.Module):
        key = key.flatten(-2)
        return query, key

-    def forward(
+    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        from vllm import _custom_ops as ops
+
        self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
                                                   dtype=query.dtype)
        # ops.rotary_embedding()/batched_rotary_embedding()