Optimize GeGLU layer in Gemma (#2975)

parent 93dc5a2870
commit fd5dcc5c81
@@ -2,19 +2,16 @@
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {

template<typename T>
__device__ __forceinline__ T silu(const T& x) {
  // x * sigmoid(x)
  return (T) (((float) x) / (1.0f + expf((float) -x)));
}

template<typename scalar_t>
__global__ void silu_and_mul_kernel(
// Activation and gating kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
  scalar_t* __restrict__ out,               // [..., d]
  const scalar_t* __restrict__ input,       // [..., 2, d]
  const int d) {
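The old file pairs a dedicated silu() device function with a SiLU-only silu_and_mul_kernel; the new act_and_mul_kernel instead takes the activation as a compile-time function-pointer template parameter (ACT_FN), so a single kernel body serves any gated activation. A PyTorch-native sketch of the computation, for illustration only (not part of the commit), on an input of shape [..., 2 * d]:

# PyTorch-native sketch of what act_and_mul_kernel computes: the activation is
# passed in as a function, mirroring the ACT_FN template parameter, applied to
# the first half of the last dimension and multiplied by the second half.
import torch
import torch.nn.functional as F


def act_and_mul_ref(act_fn, x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return act_fn(x[..., :d]) * x[..., d:]


x = torch.randn(4, 2 * 8)
silu_and_mul_out = act_and_mul_ref(F.silu, x)  # SwiGLU-style gating
gelu_and_mul_out = act_and_mul_ref(F.gelu, x)  # GeGLU-style gating

Instantiating the template with silu_kernel or gelu_kernel (defined in the next hunk) gives the SwiGLU and GeGLU variants with no runtime dispatch overhead.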
@@ -22,32 +19,58 @@ __global__ void silu_and_mul_kernel(
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
    const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = silu(x) * y;
    out[token_idx * d + idx] = ACT_FN(x) * y;
  }
}

template<typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  // x * sigmoid(x)
  return (T) (((float) x) / (1.0f + expf((float) -x)));
}

template<typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
  // Equivalent to PyTorch GELU with 'none' approximation.
  // Refer to:
  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38
  const float f = (float) x;
  constexpr float ALPHA = M_SQRT1_2;
  return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
}

} // namespace vllm

// Launch activation and gating kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
  int d = input.size(-1) / 2; \
  int64_t num_tokens = input.numel() / input.size(-1); \
  dim3 grid(num_tokens); \
  dim3 block(std::min(d, 1024)); \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
  VLLM_DISPATCH_FLOATING_TYPES( \
    input.scalar_type(), \
    "act_and_mul_kernel", \
    [&] { \
      vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
        out.data_ptr<scalar_t>(), \
        input.data_ptr<scalar_t>(), \
        d); \
    });

void silu_and_mul(
  torch::Tensor& out,      // [..., d]
  torch::Tensor& input)    // [..., 2 * d]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1) / 2;
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}

  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "silu_and_mul_kernel",
    [&] {
      vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        d);
    });

void gelu_and_mul(
  torch::Tensor& out,      // [..., d]
  torch::Tensor& input)    // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
}

namespace vllm {
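silu_kernel and gelu_kernel are the device-side activations plugged into the template above; gelu_kernel uses the erf form x * 0.5 * (1 + erf(x / sqrt(2))), which the comment states is equivalent to PyTorch's exact ('none') GELU. On the host side, the grid/block/stream/dispatch setup that silu_and_mul previously wrote out by hand is folded into the LAUNCH_ACTIVATION_GATE_KERNEL macro, and both silu_and_mul and the new gelu_and_mul now reduce to a single macro invocation with the matching device activation. A quick numerical check of the erf formula against PyTorch (a sketch, not from the diff; requires a PyTorch build with the approximate keyword):

# Sketch: verify the erf-based formula used by gelu_kernel against PyTorch's
# exact GELU (approximate="none"). ALPHA in the kernel is M_SQRT1_2 = 1/sqrt(2).
import math

import torch
import torch.nn.functional as F

x = torch.randn(4096, dtype=torch.float32)
erf_gelu = x * 0.5 * (1.0 + torch.erf(x * math.sqrt(0.5)))
assert torch.allclose(erf_gelu, F.gelu(x, approximate="none"))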
@@ -57,6 +57,10 @@ void silu_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_new(
  torch::Tensor& out,
  torch::Tensor& input);
@@ -22,6 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    "silu_and_mul",
    &silu_and_mul,
    "Activation function used in SwiGLU.");
  ops.def(
    "gelu_and_mul",
    &gelu_and_mul,
    "Activation function used in GeGLU.");
  ops.def(
    "gelu_new",
    &gelu_new,
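Together with the declaration added to the C++ header above, this registers gelu_and_mul alongside silu_and_mul in the Python extension. The binding follows the same in-place convention as silu_and_mul: the caller pre-allocates out with half the last dimension of input. A usage sketch (the vllm._C import path is an assumption about where the extension is exposed at this revision; assumes a CUDA device):

# Sketch of the calling convention for the new binding. The import path is an
# assumption; the activation layer below reaches the op through its `ops` alias.
import torch

from vllm._C import ops  # assumed location of the compiled extension

x = torch.randn(16, 2 * 4096, dtype=torch.half, device="cuda")
out = torch.empty(16, 4096, dtype=x.dtype, device=x.device)
ops.gelu_and_mul(out, x)  # writes GELU(x[:, :4096]) * x[:, 4096:] into out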
@@ -1,7 +1,10 @@
from typing import Type

import pytest
import torch

from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul
from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
                                                   NewGELU, SiluAndMul)
from allclose_default import get_default_atol, get_default_rtol

DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -13,13 +16,15 @@ CUDA_DEVICES = [
]


@pytest.mark.parametrize("activation", [SiluAndMul, GeluAndMul])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_silu_and_mul(
def test_act_and_mul(
    activation: Type[torch.nn.Module],
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
@@ -31,22 +36,23 @@ def test_silu_and_mul(
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, 2 * d, dtype=dtype)
    layer = SiluAndMul()
    layer = activation()
    out = layer(x)
    ref_out = layer._forward(x)
    assert torch.allclose(out,
                          ref_out,
                          atol=get_default_atol(out),
                          rtol=get_default_rtol(out))
    # The SiLU and GELU implementations are equivalent to the native PyTorch
    # implementations, so we can do exact comparison.
    assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)


@pytest.mark.parametrize("activation", [FastGELU, NewGELU])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gelu_new(
def test_activation(
    activation: Type[torch.nn.Module],
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
@@ -58,33 +64,7 @@ def test_gelu_new(
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, d, dtype=dtype)
    layer = NewGELU()
    out = layer(x)
    ref_out = layer._forward(x)
    assert torch.allclose(out,
                          ref_out,
                          atol=get_default_atol(out),
                          rtol=get_default_rtol(out))


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_gelu_fast(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, d, dtype=dtype)
    layer = FastGELU()
    layer = activation()
    out = layer(x)
    ref_out = layer._forward(x)
    assert torch.allclose(out,
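The previously separate test_silu_and_mul, test_gelu_new, and test_gelu_fast bodies are collapsed into two parametrized tests: test_act_and_mul covers the gated activations (SiluAndMul, GeluAndMul) and test_activation covers the plain GELU variants (FastGELU, NewGELU). Each compares the CUDA forward() against the PyTorch-native _forward(), and the gated test now asserts exact equality. A condensed, self-contained sketch of that pattern (trimmed parameter grid, assumes a CUDA device):

# Condensed sketch of the consolidated test pattern; the real test file sweeps
# dtypes, sizes, seeds, and devices via pytest parameters.
import pytest
import torch

from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul


@pytest.mark.parametrize("activation", [SiluAndMul, GeluAndMul])
@torch.inference_mode()
def test_act_and_mul_sketch(activation) -> None:
    torch.manual_seed(0)
    x = torch.randn(7, 2 * 512, dtype=torch.half, device="cuda")
    layer = activation()
    out = layer(x)               # fused CUDA kernel
    ref_out = layer._forward(x)  # PyTorch-native reference
    assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)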
@@ -37,6 +37,29 @@ class SiluAndMul(nn.Module):
        return out


class GeluAndMul(nn.Module):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d]) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.gelu_and_mul(out, x)
        return out


class NewGELU(nn.Module):

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
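GeluAndMul mirrors the existing SiluAndMul: _forward() is the PyTorch-native reference, while forward() allocates the output tensor and dispatches to the fused ops.gelu_and_mul kernel. A short usage sketch with hypothetical sizes, following the shapes in the docstring:

# Usage sketch for GeluAndMul (hypothetical sizes; shapes per the docstring).
import torch

from vllm.model_executor.layers.activation import GeluAndMul

layer = GeluAndMul()
x = torch.randn(8, 2 * 11008, dtype=torch.half, device="cuda")  # [num_tokens, 2 * d]
y = layer(x)                                                    # [num_tokens, d]
assert y.shape == (8, 11008)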
@@ -21,10 +21,11 @@ from torch import nn
from transformers import GemmaConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -50,27 +51,21 @@ class GemmaMLP(nn.Module):
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.gate_proj = ColumnParallelLinear(hidden_size,
                                              intermediate_size,
                                              bias=False,
                                              linear_method=linear_method)
        self.up_proj = ColumnParallelLinear(hidden_size,
                                            intermediate_size,
                                            bias=False,
                                            linear_method=linear_method)
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           linear_method=linear_method)
        self.act_fn = nn.GELU()
        self.act_fn = GeluAndMul()

    def forward(self, x):
        gate, _ = self.gate_proj(x)
        gate = self.act_fn(gate)
        up, _ = self.up_proj(x)
        fuse = gate * up
        outputs, _ = self.down_proj(fuse)
        return outputs
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class GemmaAttention(nn.Module):
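The MLP rewrite does two things: gate_proj and up_proj are merged into a single MergedColumnParallelLinear, so one GEMM produces the concatenated [gate, up] tensor, and the unfused nn.GELU-then-multiply is replaced by the fused GeluAndMul kernel over that tensor's two halves. A small equivalence sketch with plain matmuls standing in for the parallel linear layers (a simplification, not the vLLM classes):

# Equivalence sketch: merged gate/up projection + gated activation vs. the old
# separate projections with an explicit GELU and multiply.
import torch
import torch.nn.functional as F

hidden_size, intermediate_size = 16, 32
x = torch.randn(4, hidden_size)
w_gate = torch.randn(intermediate_size, hidden_size)
w_up = torch.randn(intermediate_size, hidden_size)

# Old path: two matmuls, then an unfused GELU and elementwise multiply.
old = F.gelu(x @ w_gate.t()) * (x @ w_up.t())

# New path: one merged matmul, then a gated activation over the two halves.
w_gate_up = torch.cat([w_gate, w_up], dim=0)
gate_up = x @ w_gate_up.t()
new = F.gelu(gate_up[..., :intermediate_size]) * gate_up[..., intermediate_size:]

assert torch.allclose(old, new, atol=1e-5)

The merged projection turns the two column-parallel GEMMs into one, and GeluAndMul replaces the separate elementwise GELU and multiply kernels with a single fused launch.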
@@ -294,6 +289,8 @@ class GemmaForCausalLM(nn.Module):
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()
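Because checkpoints still store separate gate_proj and up_proj weights, two entries are added to stacked_params_mapping so those weights load into shards 0 and 1 of the merged gate_up_proj parameter. A condensed sketch of how a typical vLLM load_weights loop consumes these entries (assumed structure; the full loop is not shown in this diff):

# Sketch (assumed structure) of the load_weights pattern that uses the mapping:
# checkpoint names like "...mlp.gate_proj.weight" are rewritten to the merged
# parameter name and passed to its weight_loader with the matching shard id.
def load_weights_sketch(model, weights, stacked_params_mapping):
    params_dict = dict(model.named_parameters())
    for name, loaded_weight in weights:
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            # e.g. "model.layers.0.mlp.gate_proj.weight"
            #   -> "model.layers.0.mlp.gate_up_proj.weight", shard 0.
            param = params_dict[name.replace(weight_name, param_name)]
            param.weight_loader(param, loaded_weight, shard_id)
            break
        else:
            param = params_dict[name]
            # Parameters without a custom loader fall back to a plain copy in
            # the real implementation.
            weight_loader = getattr(param, "weight_loader", None)
            if weight_loader is not None:
                weight_loader(param, loaded_weight)
            else:
                param.data.copy_(loaded_weight)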