mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-29 03:47:05 +08:00
Add PyTorch-native implementation of custom layers (#1898)
This commit is contained in:
parent
5313c2cb8b
commit
9b294976a2
@ -1,9 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
from transformers.activations import get_activation
|
|
||||||
|
|
||||||
from vllm._C import ops
|
from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul
|
||||||
|
|
||||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||||
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
|
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
|
||||||
@ -11,11 +9,6 @@ D = [512, 4096, 5120, 13824] # Arbitrary values for testing
|
|||||||
SEEDS = [0]
|
SEEDS = [0]
|
||||||
|
|
||||||
|
|
||||||
def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
|
|
||||||
x1, x2 = x.chunk(chunks=2, dim=1)
|
|
||||||
return F.silu(x1) * x2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||||
@pytest.mark.parametrize("d", D)
|
@pytest.mark.parametrize("d", D)
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
@ -30,9 +23,9 @@ def test_silu_and_mul(
|
|||||||
torch.random.manual_seed(seed)
|
torch.random.manual_seed(seed)
|
||||||
torch.cuda.manual_seed(seed)
|
torch.cuda.manual_seed(seed)
|
||||||
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
|
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
|
||||||
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
|
layer = SiluAndMul()
|
||||||
ops.silu_and_mul(out, x)
|
out = layer(x)
|
||||||
ref_out = ref_silu_and_mul(x)
|
ref_out = layer._forward(x)
|
||||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
@ -50,9 +43,9 @@ def test_gelu_new(
|
|||||||
torch.random.manual_seed(seed)
|
torch.random.manual_seed(seed)
|
||||||
torch.cuda.manual_seed(seed)
|
torch.cuda.manual_seed(seed)
|
||||||
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
|
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
|
||||||
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
|
layer = NewGELU()
|
||||||
ops.gelu_new(out, x)
|
out = layer(x)
|
||||||
ref_out = get_activation("gelu_new")(x)
|
ref_out = layer._forward(x)
|
||||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
@ -69,7 +62,7 @@ def test_gelu_fast(
|
|||||||
torch.random.manual_seed(seed)
|
torch.random.manual_seed(seed)
|
||||||
torch.cuda.manual_seed(seed)
|
torch.cuda.manual_seed(seed)
|
||||||
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
|
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
|
||||||
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
|
layer = FastGELU()
|
||||||
ops.gelu_fast(out, x)
|
out = layer(x)
|
||||||
ref_out = get_activation("gelu_fast")(x)
|
ref_out = layer._forward(x)
|
||||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||||
|
|||||||
@ -1,58 +1,47 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from vllm._C import ops
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
|
||||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||||
HIDDEN_SIZES = [67, 768, 2048, 5120, 8192] # Arbitrary values for testing
|
|
||||||
NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
|
NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
|
||||||
|
HIDDEN_SIZES = [768, 5120, 8192] # Arbitrary values for testing
|
||||||
|
ADD_RESIDUAL = [False, True]
|
||||||
SEEDS = [0]
|
SEEDS = [0]
|
||||||
|
|
||||||
|
|
||||||
class RefRMSNorm(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
|
||||||
super().__init__()
|
|
||||||
weight = torch.empty(hidden_size)
|
|
||||||
weight.normal_(mean=1.0, std=0.1)
|
|
||||||
self.weight = nn.Parameter(weight)
|
|
||||||
self.variance_epsilon = eps
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
input_dtype = hidden_states.dtype
|
|
||||||
hidden_states = hidden_states.to(torch.float32)
|
|
||||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
|
||||||
hidden_states = hidden_states * torch.rsqrt(variance +
|
|
||||||
self.variance_epsilon)
|
|
||||||
return self.weight * hidden_states.to(input_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||||
|
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
@pytest.mark.parametrize("seed", SEEDS)
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_rms_norm(
|
def test_rms_norm(
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
|
add_residual: bool,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
seed: int,
|
seed: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.random.manual_seed(seed)
|
torch.random.manual_seed(seed)
|
||||||
torch.cuda.manual_seed(seed)
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
scale = float(hidden_size**-0.5)
|
layer = RMSNorm(hidden_size).to(dtype).cuda()
|
||||||
x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
|
layer.weight.data.normal_(mean=1.0, std=0.1)
|
||||||
x.uniform_(-scale, scale)
|
scale = 1 / (2 * hidden_size)
|
||||||
ref = RefRMSNorm(hidden_size).to(dtype).cuda()
|
x = torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda")
|
||||||
|
x *= scale
|
||||||
|
residual = torch.randn_like(x) * scale if add_residual else None
|
||||||
|
|
||||||
out = torch.empty_like(x)
|
# NOTE(woosuk): The reference implementation should be executed first
|
||||||
ops.rms_norm(
|
# because the custom kernel is in-place.
|
||||||
out,
|
ref_out = layer._forward(x, residual)
|
||||||
x,
|
out = layer(x, residual)
|
||||||
ref.weight.data,
|
# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
|
||||||
ref.variance_epsilon,
|
# numerical errors than other operators because they involve reductions.
|
||||||
)
|
# Therefore, we use a larger tolerance.
|
||||||
ref_out = ref(x)
|
if add_residual:
|
||||||
assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5)
|
assert torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2)
|
||||||
|
assert torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
|
||||||
|
else:
|
||||||
|
assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2)
|
||||||
|
|||||||
@ -1,105 +1,23 @@
|
|||||||
from typing import Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from vllm._C import ops
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
|
||||||
IS_NEOX_STYLE = [True, False]
|
IS_NEOX_STYLE = [True, False]
|
||||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||||
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||||
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
|
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
|
||||||
NUM_HEADS = [7, 12, 40, 52] # Arbitrary values for testing
|
NUM_HEADS = [7, 17] # Arbitrary values for testing
|
||||||
NUM_TOKENS = [11, 83, 2048] # Arbitrary values for testing
|
BATCH_SIZES = [1, 5] # Arbitrary values for testing
|
||||||
|
SEQ_LENS = [11, 8192] # Arbitrary values for testing
|
||||||
SEEDS = [0]
|
SEEDS = [0]
|
||||||
|
|
||||||
|
|
||||||
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
|
||||||
x1 = x[..., :x.shape[-1] // 2]
|
|
||||||
x2 = x[..., x.shape[-1] // 2:]
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
|
|
||||||
x1 = x[..., ::2]
|
|
||||||
x2 = x[..., 1::2]
|
|
||||||
x = torch.stack((-x2, x1), dim=-1)
|
|
||||||
return x.flatten(-2)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rope(
|
|
||||||
q: torch.Tensor,
|
|
||||||
k: torch.Tensor,
|
|
||||||
cos: torch.Tensor,
|
|
||||||
sin: torch.Tensor,
|
|
||||||
is_neox_style: bool,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
rotate_fn = rotate_neox if is_neox_style else rotate_gptj
|
|
||||||
q_embed = (q * cos) + (rotate_fn(q) * sin)
|
|
||||||
k_embed = (k * cos) + (rotate_fn(k) * sin)
|
|
||||||
return q_embed, k_embed
|
|
||||||
|
|
||||||
|
|
||||||
class RefRotaryEmbedding(nn.Module):
|
|
||||||
"""Reference implementation of rotary embedding."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim: int,
|
|
||||||
is_neox_style: bool,
|
|
||||||
max_position_embeddings: int = 8192,
|
|
||||||
base: int = 10000,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.rotary_dim = dim
|
|
||||||
self.is_neox_style = is_neox_style
|
|
||||||
self.max_position_embeddings = max_position_embeddings
|
|
||||||
|
|
||||||
# Create cos and sin embeddings.
|
|
||||||
inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
|
|
||||||
t = torch.arange(max_position_embeddings).float()
|
|
||||||
freqs = torch.einsum("i,j->ij", t, inv_freq.float())
|
|
||||||
if is_neox_style:
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
|
||||||
else:
|
|
||||||
emb = torch.repeat_interleave(freqs, 2, -1)
|
|
||||||
cos = emb.cos().to(dtype=inv_freq.dtype)
|
|
||||||
sin = emb.sin().to(dtype=inv_freq.dtype)
|
|
||||||
self.register_buffer("cos_cached", cos, persistent=False)
|
|
||||||
self.register_buffer("sin_cached", sin, persistent=False)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
positions: torch.Tensor, # [num_tokens]
|
|
||||||
query: torch.Tensor, # [num_tokens, num_heads, head_size]
|
|
||||||
key: torch.Tensor, # [num_tokens, num_heads, head_size]
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
query_rot = query[..., :self.rotary_dim]
|
|
||||||
query_pass = query[..., self.rotary_dim:]
|
|
||||||
key_rot = key[..., :self.rotary_dim]
|
|
||||||
key_pass = key[..., self.rotary_dim:]
|
|
||||||
|
|
||||||
query_rot = query_rot.transpose(0, 1)
|
|
||||||
key_rot = key_rot.transpose(0, 1)
|
|
||||||
cos = F.embedding(positions, self.cos_cached)
|
|
||||||
sin = F.embedding(positions, self.sin_cached)
|
|
||||||
|
|
||||||
query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
|
|
||||||
self.is_neox_style)
|
|
||||||
query_rot = query_rot.transpose(0, 1).contiguous()
|
|
||||||
key_rot = key_rot.transpose(0, 1).contiguous()
|
|
||||||
|
|
||||||
query = torch.cat((query_rot, query_pass), dim=-1)
|
|
||||||
key = torch.cat((key_rot, key_pass), dim=-1)
|
|
||||||
|
|
||||||
# Output query/key shape: [num_tokens, num_tokens, head_size]
|
|
||||||
return query, key
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||||
|
@pytest.mark.parametrize("seq_len", SEQ_LENS)
|
||||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
|
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
|
||||||
@ -108,7 +26,8 @@ class RefRotaryEmbedding(nn.Module):
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_rotary_embedding(
|
def test_rotary_embedding(
|
||||||
is_neox_style: bool,
|
is_neox_style: bool,
|
||||||
num_tokens: int,
|
batch_size: int,
|
||||||
|
seq_len: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
rotary_dim: Optional[int],
|
rotary_dim: Optional[int],
|
||||||
@ -122,53 +41,25 @@ def test_rotary_embedding(
|
|||||||
torch.random.manual_seed(seed)
|
torch.random.manual_seed(seed)
|
||||||
torch.cuda.manual_seed(seed)
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
|
if rotary_dim is None:
|
||||||
query = torch.randn(num_tokens,
|
rotary_dim = head_size
|
||||||
|
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
|
||||||
|
rope = rope.to(dtype).cuda()
|
||||||
|
|
||||||
|
positions = torch.randint(0,
|
||||||
|
max_position, (batch_size, seq_len),
|
||||||
|
device="cuda")
|
||||||
|
query = torch.randn(batch_size,
|
||||||
|
seq_len,
|
||||||
num_heads * head_size,
|
num_heads * head_size,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device="cuda")
|
device="cuda")
|
||||||
key = torch.randn(num_tokens,
|
key = torch.randn_like(query)
|
||||||
num_heads * head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device="cuda")
|
|
||||||
|
|
||||||
# Create the rotary embedding.
|
|
||||||
inv_freq = 1.0 / (base**(
|
|
||||||
torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
|
|
||||||
t = torch.arange(max_position).float()
|
|
||||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
|
||||||
cos = freqs.cos()
|
|
||||||
sin = freqs.sin()
|
|
||||||
cos_sin_cache = torch.cat((cos, sin), dim=-1)
|
|
||||||
cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")
|
|
||||||
|
|
||||||
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
|
|
||||||
out_query = query.clone()
|
|
||||||
out_key = key.clone()
|
|
||||||
ops.rotary_embedding(
|
|
||||||
positions,
|
|
||||||
out_query,
|
|
||||||
out_key,
|
|
||||||
head_size,
|
|
||||||
cos_sin_cache,
|
|
||||||
is_neox_style,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run the reference implementation.
|
|
||||||
ref_rotary_embedding = RefRotaryEmbedding(
|
|
||||||
dim=rotary_dim,
|
|
||||||
is_neox_style=is_neox_style,
|
|
||||||
max_position_embeddings=max_position,
|
|
||||||
base=base,
|
|
||||||
).to(dtype=dtype, device="cuda")
|
|
||||||
ref_query, ref_key = ref_rotary_embedding(
|
|
||||||
positions,
|
|
||||||
query.view(num_tokens, num_heads, head_size),
|
|
||||||
key.view(num_tokens, num_heads, head_size),
|
|
||||||
)
|
|
||||||
ref_query = ref_query.view(num_tokens, num_heads * head_size)
|
|
||||||
ref_key = ref_key.view(num_tokens, num_heads * head_size)
|
|
||||||
|
|
||||||
|
# NOTE(woosuk): The reference implementation should be executed first
|
||||||
|
# because the custom kernel is in-place.
|
||||||
|
ref_query, ref_key = rope._forward(positions, query, key)
|
||||||
|
out_query, out_key = rope.forward(positions, query, key)
|
||||||
# Compare the results.
|
# Compare the results.
|
||||||
assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
|
assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
|
||||||
assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
|
assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
"""Custom activation functions."""
|
"""Custom activation functions."""
|
||||||
|
import math
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from vllm._C import ops
|
from vllm._C import ops
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
@ -22,6 +24,11 @@ class SiluAndMul(nn.Module):
|
|||||||
return: (batch_size, seq_len, d) or (num_tokens, d)
|
return: (batch_size, seq_len, d) or (num_tokens, d)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def _forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""PyTorch-native implementation equivalent to forward()."""
|
||||||
|
d = x.shape[-1] // 2
|
||||||
|
return F.silu(x[..., :d]) * x[..., d:]
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
d = x.shape[-1] // 2
|
d = x.shape[-1] // 2
|
||||||
output_shape = (x.shape[:-1] + (d, ))
|
output_shape = (x.shape[:-1] + (d, ))
|
||||||
@ -32,6 +39,12 @@ class SiluAndMul(nn.Module):
|
|||||||
|
|
||||||
class NewGELU(nn.Module):
|
class NewGELU(nn.Module):
|
||||||
|
|
||||||
|
def _forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""PyTorch-native implementation equivalent to forward()."""
|
||||||
|
c = math.sqrt(2.0 / math.pi)
|
||||||
|
return 0.5 * x * (1.0 + torch.tanh(c *
|
||||||
|
(x + 0.044715 * torch.pow(x, 3.0))))
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
out = torch.empty_like(x)
|
out = torch.empty_like(x)
|
||||||
ops.gelu_new(out, x)
|
ops.gelu_new(out, x)
|
||||||
@ -40,6 +53,11 @@ class NewGELU(nn.Module):
|
|||||||
|
|
||||||
class FastGELU(nn.Module):
|
class FastGELU(nn.Module):
|
||||||
|
|
||||||
|
def _forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""PyTorch-native implementation equivalent to forward()."""
|
||||||
|
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
|
||||||
|
(1.0 + 0.044715 * x * x)))
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
out = torch.empty_like(x)
|
out = torch.empty_like(x)
|
||||||
ops.gelu_fast(out, x)
|
ops.gelu_fast(out, x)
|
||||||
|
|||||||
@ -23,6 +23,26 @@ class RMSNorm(nn.Module):
|
|||||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||||
self.variance_epsilon = eps
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def _forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
"""PyTorch-native implementation equivalent to forward()."""
|
||||||
|
orig_dtype = x.dtype
|
||||||
|
x = x.to(torch.float32)
|
||||||
|
if residual is not None:
|
||||||
|
x = x + residual.to(torch.float32)
|
||||||
|
residual = x.to(orig_dtype)
|
||||||
|
|
||||||
|
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||||
|
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
x = x.to(orig_dtype) * self.weight
|
||||||
|
if residual is None:
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
return x, residual
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
|
|||||||
@ -30,6 +30,19 @@ import torch.nn as nn
|
|||||||
from vllm._C import ops
|
from vllm._C import ops
|
||||||
|
|
||||||
|
|
||||||
|
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x1 = x[..., :x.shape[-1] // 2]
|
||||||
|
x2 = x[..., x.shape[-1] // 2:]
|
||||||
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x1 = x[..., ::2]
|
||||||
|
x2 = x[..., 1::2]
|
||||||
|
x = torch.stack((-x2, x1), dim=-1)
|
||||||
|
return x.flatten(-2)
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(nn.Module):
|
class RotaryEmbedding(nn.Module):
|
||||||
"""Original rotary positional embedding."""
|
"""Original rotary positional embedding."""
|
||||||
|
|
||||||
@ -81,6 +94,47 @@ class RotaryEmbedding(nn.Module):
|
|||||||
cache = torch.cat((cos, sin), dim=-1)
|
cache = torch.cat((cos, sin), dim=-1)
|
||||||
return cache
|
return cache
|
||||||
|
|
||||||
|
def _forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""PyTorch-native implementation equivalent to forward()."""
|
||||||
|
query = query.view(*query.shape[:-1], -1, self.head_size)
|
||||||
|
key = key.view(*key.shape[:-1], -1, self.head_size)
|
||||||
|
|
||||||
|
query_rot = query[..., :self.rotary_dim]
|
||||||
|
key_rot = key[..., :self.rotary_dim]
|
||||||
|
if self.rotary_dim < self.head_size:
|
||||||
|
query_pass = query[..., self.rotary_dim:]
|
||||||
|
key_pass = key[..., self.rotary_dim:]
|
||||||
|
|
||||||
|
cos_sin = self.cos_sin_cache[positions]
|
||||||
|
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||||
|
if self.is_neox_style:
|
||||||
|
# NOTE(woosuk): Here we assume that the positions tensor has the
|
||||||
|
# shape [batch_size, seq_len].
|
||||||
|
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
|
||||||
|
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
|
||||||
|
else:
|
||||||
|
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
|
||||||
|
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
|
||||||
|
|
||||||
|
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
|
||||||
|
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
|
||||||
|
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
|
||||||
|
|
||||||
|
if self.rotary_dim < self.head_size:
|
||||||
|
query = torch.cat((query_rot, query_pass), dim=-1)
|
||||||
|
key = torch.cat((key_rot, key_pass), dim=-1)
|
||||||
|
else:
|
||||||
|
query = query_rot
|
||||||
|
key = key_rot
|
||||||
|
query = query.flatten(-2)
|
||||||
|
key = key.flatten(-2)
|
||||||
|
return query, key
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user