Add PyTorch-native implementation of custom layers (#1898)

This commit is contained in:
Woosuk Kwon 2023-12-02 21:18:40 -08:00 committed by GitHub
parent 5313c2cb8b
commit 9b294976a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 149 additions and 184 deletions

View File

@ -1,9 +1,7 @@
import pytest import pytest
import torch import torch
import torch.nn.functional as F
from transformers.activations import get_activation
from vllm._C import ops from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
@ -11,11 +9,6 @@ D = [512, 4096, 5120, 13824] # Arbitrary values for testing
SEEDS = [0] SEEDS = [0]
def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
x1, x2 = x.chunk(chunks=2, dim=1)
return F.silu(x1) * x2
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D) @pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@ -30,9 +23,9 @@ def test_silu_and_mul(
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda") x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") layer = SiluAndMul()
ops.silu_and_mul(out, x) out = layer(x)
ref_out = ref_silu_and_mul(x) ref_out = layer._forward(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
@ -50,9 +43,9 @@ def test_gelu_new(
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") layer = NewGELU()
ops.gelu_new(out, x) out = layer(x)
ref_out = get_activation("gelu_new")(x) ref_out = layer._forward(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
@ -69,7 +62,7 @@ def test_gelu_fast(
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") layer = FastGELU()
ops.gelu_fast(out, x) out = layer(x)
ref_out = get_activation("gelu_fast")(x) ref_out = layer._forward(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)

View File

@ -1,58 +1,47 @@
import pytest import pytest
import torch import torch
import torch.nn as nn
from vllm._C import ops from vllm.model_executor.layers.layernorm import RMSNorm
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [67, 768, 2048, 5120, 8192] # Arbitrary values for testing
NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
HIDDEN_SIZES = [768, 5120, 8192] # Arbitrary values for testing
ADD_RESIDUAL = [False, True]
SEEDS = [0] SEEDS = [0]
class RefRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
weight = torch.empty(hidden_size)
weight.normal_(mean=1.0, std=0.1)
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance +
self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode() @torch.inference_mode()
def test_rms_norm( def test_rms_norm(
num_tokens: int, num_tokens: int,
hidden_size: int, hidden_size: int,
add_residual: bool,
dtype: torch.dtype, dtype: torch.dtype,
seed: int, seed: int,
) -> None: ) -> None:
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
scale = float(hidden_size**-0.5) layer = RMSNorm(hidden_size).to(dtype).cuda()
x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda") layer.weight.data.normal_(mean=1.0, std=0.1)
x.uniform_(-scale, scale) scale = 1 / (2 * hidden_size)
ref = RefRMSNorm(hidden_size).to(dtype).cuda() x = torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda")
x *= scale
residual = torch.randn_like(x) * scale if add_residual else None
out = torch.empty_like(x) # NOTE(woosuk): The reference implementation should be executed first
ops.rms_norm( # because the custom kernel is in-place.
out, ref_out = layer._forward(x, residual)
x, out = layer(x, residual)
ref.weight.data, # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
ref.variance_epsilon, # numerical errors than other operators because they involve reductions.
) # Therefore, we use a larger tolerance.
ref_out = ref(x) if add_residual:
assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5) assert torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2)
assert torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
else:
assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2)

View File

@ -1,105 +1,23 @@
from typing import Optional, Tuple from typing import Optional
import pytest import pytest
import torch import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm._C import ops from vllm.model_executor.layers.rotary_embedding import get_rope
IS_NEOX_STYLE = [True, False] IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 128, 256] HEAD_SIZES = [64, 80, 96, 112, 128, 256]
ROTARY_DIMS = [None, 32] # None means rotary dim == head size ROTARY_DIMS = [None, 32] # None means rotary dim == head size
NUM_HEADS = [7, 12, 40, 52] # Arbitrary values for testing NUM_HEADS = [7, 17] # Arbitrary values for testing
NUM_TOKENS = [11, 83, 2048] # Arbitrary values for testing BATCH_SIZES = [1, 5] # Arbitrary values for testing
SEQ_LENS = [11, 8192] # Arbitrary values for testing
SEEDS = [0] SEEDS = [0]
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., ::2]
x2 = x[..., 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2)
def apply_rope(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
rotate_fn = rotate_neox if is_neox_style else rotate_gptj
q_embed = (q * cos) + (rotate_fn(q) * sin)
k_embed = (k * cos) + (rotate_fn(k) * sin)
return q_embed, k_embed
class RefRotaryEmbedding(nn.Module):
"""Reference implementation of rotary embedding."""
def __init__(
self,
dim: int,
is_neox_style: bool,
max_position_embeddings: int = 8192,
base: int = 10000,
) -> None:
super().__init__()
self.rotary_dim = dim
self.is_neox_style = is_neox_style
self.max_position_embeddings = max_position_embeddings
# Create cos and sin embeddings.
inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
t = torch.arange(max_position_embeddings).float()
freqs = torch.einsum("i,j->ij", t, inv_freq.float())
if is_neox_style:
emb = torch.cat((freqs, freqs), dim=-1)
else:
emb = torch.repeat_interleave(freqs, 2, -1)
cos = emb.cos().to(dtype=inv_freq.dtype)
sin = emb.sin().to(dtype=inv_freq.dtype)
self.register_buffer("cos_cached", cos, persistent=False)
self.register_buffer("sin_cached", sin, persistent=False)
def forward(
self,
positions: torch.Tensor, # [num_tokens]
query: torch.Tensor, # [num_tokens, num_heads, head_size]
key: torch.Tensor, # [num_tokens, num_heads, head_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
query_rot = query_rot.transpose(0, 1)
key_rot = key_rot.transpose(0, 1)
cos = F.embedding(positions, self.cos_cached)
sin = F.embedding(positions, self.sin_cached)
query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
self.is_neox_style)
query_rot = query_rot.transpose(0, 1).contiguous()
key_rot = key_rot.transpose(0, 1).contiguous()
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
# Output query/key shape: [num_tokens, num_tokens, head_size]
return query, key
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@ -108,7 +26,8 @@ class RefRotaryEmbedding(nn.Module):
@torch.inference_mode() @torch.inference_mode()
def test_rotary_embedding( def test_rotary_embedding(
is_neox_style: bool, is_neox_style: bool,
num_tokens: int, batch_size: int,
seq_len: int,
num_heads: int, num_heads: int,
head_size: int, head_size: int,
rotary_dim: Optional[int], rotary_dim: Optional[int],
@ -122,53 +41,25 @@ def test_rotary_embedding(
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
positions = torch.randint(0, max_position, (num_tokens, ), device="cuda") if rotary_dim is None:
query = torch.randn(num_tokens, rotary_dim = head_size
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
rope = rope.to(dtype).cuda()
positions = torch.randint(0,
max_position, (batch_size, seq_len),
device="cuda")
query = torch.randn(batch_size,
seq_len,
num_heads * head_size, num_heads * head_size,
dtype=dtype, dtype=dtype,
device="cuda") device="cuda")
key = torch.randn(num_tokens, key = torch.randn_like(query)
num_heads * head_size,
dtype=dtype,
device="cuda")
# Create the rotary embedding.
inv_freq = 1.0 / (base**(
torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cos_sin_cache = torch.cat((cos, sin), dim=-1)
cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
out_query = query.clone()
out_key = key.clone()
ops.rotary_embedding(
positions,
out_query,
out_key,
head_size,
cos_sin_cache,
is_neox_style,
)
# Run the reference implementation.
ref_rotary_embedding = RefRotaryEmbedding(
dim=rotary_dim,
is_neox_style=is_neox_style,
max_position_embeddings=max_position,
base=base,
).to(dtype=dtype, device="cuda")
ref_query, ref_key = ref_rotary_embedding(
positions,
query.view(num_tokens, num_heads, head_size),
key.view(num_tokens, num_heads, head_size),
)
ref_query = ref_query.view(num_tokens, num_heads * head_size)
ref_key = ref_key.view(num_tokens, num_heads * head_size)
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query, ref_key = rope._forward(positions, query, key)
out_query, out_key = rope.forward(positions, query, key)
# Compare the results. # Compare the results.
assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5) assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5) assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)

View File

@ -1,8 +1,10 @@
"""Custom activation functions.""" """Custom activation functions."""
import math
from typing import Optional from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from vllm._C import ops from vllm._C import ops
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
@ -22,6 +24,11 @@ class SiluAndMul(nn.Module):
return: (batch_size, seq_len, d) or (num_tokens, d) return: (batch_size, seq_len, d) or (num_tokens, d)
""" """
def _forward(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2 d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, )) output_shape = (x.shape[:-1] + (d, ))
@ -32,6 +39,12 @@ class SiluAndMul(nn.Module):
class NewGELU(nn.Module): class NewGELU(nn.Module):
def _forward(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
c = math.sqrt(2.0 / math.pi)
return 0.5 * x * (1.0 + torch.tanh(c *
(x + 0.044715 * torch.pow(x, 3.0))))
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x) out = torch.empty_like(x)
ops.gelu_new(out, x) ops.gelu_new(out, x)
@ -40,6 +53,11 @@ class NewGELU(nn.Module):
class FastGELU(nn.Module): class FastGELU(nn.Module):
def _forward(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
(1.0 + 0.044715 * x * x)))
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x) out = torch.empty_like(x)
ops.gelu_fast(out, x) ops.gelu_fast(out, x)

View File

@ -23,6 +23,26 @@ class RMSNorm(nn.Module):
self.weight = nn.Parameter(torch.ones(hidden_size)) self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps self.variance_epsilon = eps
def _forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
if residual is None:
return x
else:
return x, residual
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,

View File

@ -30,6 +30,19 @@ import torch.nn as nn
from vllm._C import ops from vllm._C import ops
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., ::2]
x2 = x[..., 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2)
class RotaryEmbedding(nn.Module): class RotaryEmbedding(nn.Module):
"""Original rotary positional embedding.""" """Original rotary positional embedding."""
@ -81,6 +94,47 @@ class RotaryEmbedding(nn.Module):
cache = torch.cat((cos, sin), dim=-1) cache = torch.cat((cos, sin), dim=-1)
return cache return cache
def _forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
if self.rotary_dim < self.head_size:
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
else:
query = query_rot
key = key_rot
query = query.flatten(-2)
key = key.flatten(-2)
return query, key
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,