[PERF] [Qwen3-next] Speed up gated RMSNorm (#26207)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 4ca204055e
commit 82e64c7a20
tests/kernels/test_fla_layernorm_guard.py (new file, 388 lines)
@@ -0,0 +1,388 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.fla.ops.layernorm_guard import (
    layer_norm_fwd,
    layernorm_fn,
    rms_norm_ref,
)
from vllm.platforms import current_platform


def layer_norm_ref(
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    """Reference implementation for both layer norm and RMS norm."""
    if is_rms_norm:
        # Use the imported rms_norm_ref for RMS norm cases
        return rms_norm_ref(
            x,
            weight,
            bias,
            z=z,
            eps=eps,
            group_size=group_size,
            norm_before_gate=norm_before_gate,
            upcast=True,
        )

    # Layer norm implementation
    dtype = x.dtype
    x = x.float()
    weight = weight.float()
    bias = bias.float() if bias is not None else None
    z = z.float() if z is not None else None

    if z is not None and not norm_before_gate:
        x = x * F.silu(z)

    if group_size is None:
        # Layer norm: subtract mean
        mean = x.mean(dim=-1, keepdim=True)
        var = ((x - mean).square()).mean(dim=-1, keepdim=True)
        rstd = 1 / torch.sqrt(var + eps)
        out = (x - mean) * rstd * weight
        if bias is not None:
            out = out + bias
    else:
        # Group norm
        from einops import rearrange

        x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
        mean = x_group.mean(dim=-1, keepdim=True)
        var = ((x_group - mean).square()).mean(dim=-1, keepdim=True)
        rstd = 1 / torch.sqrt(var + eps)
        x_group = (x_group - mean) * rstd
        out = rearrange(x_group, "... g d -> ... (g d)") * weight
        if bias is not None:
            out = out + bias

    if z is not None and norm_before_gate:
        out *= F.silu(z)

    return out.to(dtype)


DTYPES = [torch.bfloat16, torch.float32]
# Test various M sizes to ensure rows_per_block logic works correctly
NUM_TOKENS = [
    1,
    7,
    16,
    63,
    128,
    256,
    512,
    1024,
    2048,
    4096,
    5789,
    8189,
    8191,
    16383,
    32767,
]
HIDDEN_SIZES = [64, 128, 256, 1024]
GROUP_SIZES = [None, 64, 128]  # None means full hidden size
NORM_BEFORE_GATE = [True, False]
IS_RMS_NORM = [True, False]
SEEDS = [0, 42]


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("is_rms_norm", IS_RMS_NORM)
@torch.inference_mode()
def test_layer_norm_fwd_basic(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    seed: int,
    is_rms_norm: bool,
) -> None:
    """Test basic layer norm forward pass without z (gate) tensor."""
    current_platform.seed_everything(seed)
    device = torch.device("cuda:0")

    # Create inputs
    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
    weight = torch.randn(hidden_size, dtype=dtype, device=device)
    bias = None if is_rms_norm else torch.randn(hidden_size, dtype=dtype, device=device)
    eps = 1e-6

    # Run the triton kernel
    out, mean, rstd = layer_norm_fwd(
        x, weight, bias, eps, z=None, is_rms_norm=is_rms_norm
    )

    # Run reference implementation
    ref_out = layer_norm_ref(x, weight, bias, z=None, eps=eps, is_rms_norm=is_rms_norm)

    # Check outputs
    assert out.shape == x.shape
    assert out.dtype == x.dtype
    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)

    # Check mean and rstd shapes
    if not is_rms_norm:
        assert mean.shape == (num_tokens,)
        assert rstd.shape == (num_tokens,)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", [128, 256, 1024])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("norm_before_gate", NORM_BEFORE_GATE)
@pytest.mark.parametrize("is_rms_norm", IS_RMS_NORM)
@torch.inference_mode()
def test_layer_norm_fwd_with_gate(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    norm_before_gate: bool,
    is_rms_norm: bool,
) -> None:
    """Test layer norm forward pass with z (gate) tensor."""
    current_platform.seed_everything(42)
    device = torch.device("cuda:0")

    # Create inputs
    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
    z = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
    weight = torch.randn(hidden_size, dtype=dtype, device=device)
    bias = None if is_rms_norm else torch.randn(hidden_size, dtype=dtype, device=device)
    eps = 1e-6

    # Run the triton kernel
    out, mean, rstd = layer_norm_fwd(
        x,
        weight,
        bias,
        eps,
        z=z,
        norm_before_gate=norm_before_gate,
        is_rms_norm=is_rms_norm,
    )

    # Run reference implementation
    ref_out = layer_norm_ref(
        x,
        weight,
        bias,
        z=z,
        eps=eps,
        norm_before_gate=norm_before_gate,
        is_rms_norm=is_rms_norm,
    )

    # Check outputs
    assert out.shape == x.shape
    assert out.dtype == x.dtype
    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize("num_tokens", [128, 512])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("group_size", [64, 128, 256])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("is_rms_norm", IS_RMS_NORM)
@torch.inference_mode()
def test_layer_norm_fwd_with_groups(
    num_tokens: int,
    hidden_size: int,
    group_size: int,
    dtype: torch.dtype,
    is_rms_norm: bool,
) -> None:
    """Test layer norm forward pass with group normalization."""
    if hidden_size % group_size != 0:
        pytest.skip(
            f"hidden_size {hidden_size} not divisible by group_size {group_size}"
        )

    current_platform.seed_everything(42)
    device = torch.device("cuda:0")

    # Create inputs
    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
    weight = torch.randn(hidden_size, dtype=dtype, device=device)
    bias = None if is_rms_norm else torch.randn(hidden_size, dtype=dtype, device=device)
    eps = 1e-6

    ngroups = hidden_size // group_size

    # Run the triton kernel
    out, mean, rstd = layer_norm_fwd(
        x, weight, bias, eps, z=None, group_size=group_size, is_rms_norm=is_rms_norm
    )

    # Run reference implementation
    ref_out = layer_norm_ref(
        x, weight, bias, z=None, eps=eps, group_size=group_size, is_rms_norm=is_rms_norm
    )

    # Check outputs
    assert out.shape == x.shape
    assert out.dtype == x.dtype
    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)

    # Check mean and rstd shapes for groups
    if not is_rms_norm:
        assert mean.shape == (ngroups * num_tokens,)
        assert rstd.shape == (ngroups * num_tokens,)


@pytest.mark.parametrize("num_tokens", [7, 63, 128, 513, 1024, 2049])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@torch.inference_mode()
def test_layer_norm_rows_per_block(
    num_tokens: int,
    dtype: torch.dtype,
) -> None:
    """Test that rows_per_block logic works correctly for various M sizes."""
    current_platform.seed_everything(42)
    device = torch.device("cuda:0")
    hidden_size = 1024

    # Create inputs
    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
    weight = torch.randn(hidden_size, dtype=dtype, device=device)
    bias = torch.randn(hidden_size, dtype=dtype, device=device)
    eps = 1e-6

    # Run the triton kernel
    out, mean, rstd = layer_norm_fwd(x, weight, bias, eps, z=None, is_rms_norm=False)

    # Run reference implementation
    ref_out = layer_norm_ref(x, weight, bias, z=None, eps=eps, is_rms_norm=False)

    # Check outputs
    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode()
def test_strided_input(dtype: torch.dtype) -> None:
    """Test that the kernel handles non-contiguous (strided)
    inputs correctly."""
    current_platform.seed_everything(42)
    device = torch.device("cuda:0")
    num_tokens = 128
    hidden_size = 1024

    # Create a larger tensor and take a strided slice
    x_large = torch.randn(num_tokens, hidden_size * 2, dtype=dtype, device=device)
    x = x_large[:, :hidden_size]

    # Make it contiguous for the kernel
    x_contiguous = x.contiguous()

    weight = torch.randn(hidden_size, dtype=dtype, device=device)
    bias = torch.randn(hidden_size, dtype=dtype, device=device)
    eps = 1e-6

    # Run the triton kernel with contiguous input
    out, mean, rstd = layer_norm_fwd(
        x_contiguous, weight, bias, eps, z=None, is_rms_norm=False
    )

    # Run reference implementation
    ref_out = layer_norm_ref(
        x_contiguous, weight, bias, z=None, eps=eps, is_rms_norm=False
    )

    # Check outputs
    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize("num_tokens", [1, 128, 2048])
@pytest.mark.parametrize("hidden_size", [768, 4096])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@torch.inference_mode()
def test_output_buffer_provided(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
) -> None:
    """Test that the kernel works when an output buffer is provided."""
    current_platform.seed_everything(42)
    device = torch.device("cuda:0")

    # Create inputs
    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
    weight = torch.randn(hidden_size, dtype=dtype, device=device)
    bias = torch.randn(hidden_size, dtype=dtype, device=device)
    eps = 1e-6

    # Pre-allocate output buffer
    out_buffer = torch.empty_like(x)

    # Run the triton kernel with provided output
    out, mean, rstd = layer_norm_fwd(
        x, weight, bias, eps, z=None, out=out_buffer, is_rms_norm=False
    )

    # Check that the provided buffer was used
    assert out.data_ptr() == out_buffer.data_ptr()

    # Run reference implementation
    ref_out = layer_norm_ref(x, weight, bias, z=None, eps=eps, is_rms_norm=False)

    # Check outputs
    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize(
    "shape",
    [
        (4, 16, 1024),  # 3D tensor
        (2, 8, 512, 256),  # 4D tensor
    ],
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@torch.inference_mode()
def test_multidimensional_input(
    shape: tuple,
    dtype: torch.dtype,
) -> None:
    """Test that the autograd function handles multidimensional inputs."""
    current_platform.seed_everything(42)
    device = torch.device("cuda:0")
    hidden_size = shape[-1]

    # Create inputs
    x = torch.randn(*shape, dtype=dtype, device=device)
    weight = torch.randn(hidden_size, dtype=dtype, device=device)
    bias = torch.randn(hidden_size, dtype=dtype, device=device)
    eps = 1e-6

    # Run through autograd function
    out = layernorm_fn(x, weight, bias, z=None, eps=eps)

    # Run reference implementation
    ref_out = layer_norm_ref(x, weight, bias, z=None, eps=eps, is_rms_norm=False)

    # Check outputs
    assert out.shape == x.shape
    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
    # Run a quick smoke test
    test_layer_norm_fwd_basic(128, 1024, torch.float16, 42, False)
    test_layer_norm_fwd_with_gate(128, 1024, torch.float16, True, False)
    test_layer_norm_rows_per_block(513, torch.float16)
    print("All smoke tests passed!")
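For reference, the new tests can also be exercised through pytest rather than the __main__ smoke block; the snippet below is a hypothetical local runner, assuming pytest is installed and a CUDA device is visible (the tests hard-code cuda:0).

# Hypothetical local runner for the new test file; it mirrors the __main__
# smoke block above but goes through pytest so the parametrized cases
# (basic and gated variants) are collected and reported.
import sys

import pytest

if __name__ == "__main__":
    exit_code = pytest.main(
        ["-q", "tests/kernels/test_fla_layernorm_guard.py", "-k", "basic or gate"]
    )
    sys.exit(int(exit_code))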
vllm/model_executor/layers/fla/ops/layernorm_guard.py
@@ -13,6 +13,7 @@
 # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
 # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
 
+from functools import lru_cache
 from typing import Optional
 
 import torch
@@ -21,6 +22,7 @@ import torch.nn.functional as F
 from einops import rearrange
 
 from vllm.triton_utils import tl, triton
+from vllm.utils import cdiv, next_power_of_2
 
 from .utils import input_guard
 
@@ -76,55 +78,103 @@ def layer_norm_fwd_kernel(
     stride_y_row,
     stride_z_row,
     M,  # number of rows in X
-    N,  # number of columns in X
+    N: tl.constexpr,  # number of columns in X
     eps,  # epsilon to avoid division by zero
     BLOCK_N: tl.constexpr,
+    ROWS_PER_BLOCK: tl.constexpr,
     HAS_BIAS: tl.constexpr,
     HAS_Z: tl.constexpr,
     NORM_BEFORE_GATE: tl.constexpr,
     IS_RMS_NORM: tl.constexpr,
 ):
-    # Map the program id to the row of X and Y it should compute.
-    row = tl.program_id(0)
+    # Map the program id to the starting row of X and Y it should compute.
+    row_start = tl.program_id(0) * ROWS_PER_BLOCK
     group = tl.program_id(1)
-    X += row * stride_x_row + group * N
-    Y += row * stride_y_row + group * N
-    if HAS_Z:
-        Z += row * stride_z_row + group * N
-    if not IS_RMS_NORM:
-        Mean += group * M
-    Rstd += group * M
-    W += group * N
-    if HAS_BIAS:
-        B += group * N
-    # Compute mean and variance
+
+    # Create 2D tile: [ROWS_PER_BLOCK, BLOCK_N]
+    rows = row_start + tl.arange(0, ROWS_PER_BLOCK)
     cols = tl.arange(0, BLOCK_N)
-    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+
+    # Compute offsets for 2D tile
+    row_offsets = rows[:, None] * stride_x_row
+    col_offsets = cols[None, :] + group * N
+
+    # Base pointers
+    X_base = X + row_offsets + col_offsets
+    Y_base = Y + rows[:, None] * stride_y_row + col_offsets
+
+    # Create mask for valid rows and columns
+    row_mask = rows[:, None] < M
+    col_mask = cols[None, :] < N
+    mask = row_mask & col_mask
+
+    # Load input data with 2D tile
+    x = tl.load(X_base, mask=mask, other=0.0).to(tl.float32)
+
     if HAS_Z and not NORM_BEFORE_GATE:
-        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
+        Z_base = Z + rows[:, None] * stride_z_row + col_offsets
+        z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
         x *= z * tl.sigmoid(z)
+
+    # Compute mean and variance per row (reduce along axis 1)
     if not IS_RMS_NORM:
-        mean = tl.sum(x, axis=0) / N
-        tl.store(Mean + row, mean)
-        xbar = tl.where(cols < N, x - mean, 0.0)
-        var = tl.sum(xbar * xbar, axis=0) / N
+        mean = tl.sum(x, axis=1) / N  # Shape: [ROWS_PER_BLOCK]
+        # Store mean for each row
+        mean_offsets = group * M + rows
+        mean_mask = rows < M
+        tl.store(Mean + mean_offsets, mean, mask=mean_mask)
+        # Broadcast mean back to 2D for subtraction
+        xbar = tl.where(mask, x - mean[:, None], 0.0)
+        var = tl.sum(xbar * xbar, axis=1) / N  # Shape: [ROWS_PER_BLOCK]
     else:
-        xbar = tl.where(cols < N, x, 0.0)
-        var = tl.sum(xbar * xbar, axis=0) / N
-    rstd = 1 / tl.sqrt(var + eps)
-    tl.store(Rstd + row, rstd)
-    # Normalize and apply linear transformation
-    mask = cols < N
-    w = tl.load(W + cols, mask=mask).to(tl.float32)
+        xbar = tl.where(mask, x, 0.0)
+        var = tl.sum(xbar * xbar, axis=1) / N  # Shape: [ROWS_PER_BLOCK]
+        mean = 0.0  # Placeholder for RMS norm
+
+    rstd = tl.rsqrt(var + eps)  # Shape: [ROWS_PER_BLOCK]
+
+    # Store rstd for each row
+    rstd_offsets = group * M + rows
+    rstd_mask = rows < M
+    tl.store(Rstd + rstd_offsets, rstd, mask=rstd_mask)
+
+    # Load weights and biases (broadcast across rows)
+    w_offsets = cols + group * N
+    w_mask = cols < N
+    w = tl.load(W + w_offsets, mask=w_mask, other=0.0).to(tl.float32)
+
     if HAS_BIAS:
-        b = tl.load(B + cols, mask=mask).to(tl.float32)
-    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
-    y = x_hat * w + b if HAS_BIAS else x_hat * w
+        b = tl.load(B + w_offsets, mask=w_mask, other=0.0).to(tl.float32)
+
+    # Normalize and apply linear transformation
+    if not IS_RMS_NORM:
+        x_hat = (x - mean[:, None]) * rstd[:, None]
+    else:
+        x_hat = x * rstd[:, None]
+
+    y = x_hat * w[None, :] + b[None, :] if HAS_BIAS else x_hat * w[None, :]
+
     if HAS_Z and NORM_BEFORE_GATE:
-        z = tl.load(Z + cols, mask=mask).to(tl.float32)
+        Z_base = Z + rows[:, None] * stride_z_row + col_offsets
+        z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
         y *= z * tl.sigmoid(z)
+
     # Write output
-    tl.store(Y + cols, y, mask=mask)
+    tl.store(Y_base, y, mask=mask)
+
+
+@lru_cache
+def _get_sm_count(device: torch.device) -> int:
+    """Get and cache the SM count for a given device."""
+    props = torch.cuda.get_device_properties(device)
+    return props.multi_processor_count
+
+
+def calc_rows_per_block(M: int, device: torch.device) -> int:
+    sm_count = _get_sm_count(device)
+    rows_per_block = next_power_of_2(cdiv(M, 2 * sm_count))
+    rows_per_block = min(rows_per_block, 4)
+    return rows_per_block
 
 
 def layer_norm_fwd(
@@ -171,7 +221,10 @@ def layer_norm_fwd(
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
     # heuristics for number of warps
    num_warps = min(max(BLOCK_N // 256, 1), 8)
-    grid = (M, ngroups)
+    # Calculate rows per block based on SM count
+    rows_per_block = calc_rows_per_block(M, x.device)
+    # Update grid to use rows_per_block
+    grid = (cdiv(M, rows_per_block), ngroups)
     layer_norm_fwd_kernel[grid](
         x,
         out,
@@ -187,6 +240,7 @@ def layer_norm_fwd(
         group_size,
         eps,
         BLOCK_N=BLOCK_N,
+        ROWS_PER_BLOCK=rows_per_block,
         NORM_BEFORE_GATE=norm_before_gate,
         IS_RMS_NORM=is_rms_norm,
         num_warps=num_warps,
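For intuition about the launch change, the sketch below replays the new rows_per_block heuristic on the host. The 132-SM count is an assumption (roughly an H100-class GPU, where the kernel would query the real value), and the local cdiv and next_power_of_2 helpers are stand-ins for the vllm.utils functions imported in the diff above.

# Standalone sketch of the new grid heuristic: instead of one program per row
# (grid = (M, ngroups)), each program now handles up to 4 rows, sized so that
# roughly 2 * sm_count programs are launched for large M.


def cdiv(a: int, b: int) -> int:
    # Ceiling division, mirroring vllm.utils.cdiv.
    return -(-a // b)


def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n, mirroring vllm.utils.next_power_of_2.
    return 1 << max(n - 1, 0).bit_length()


def calc_rows_per_block(m: int, sm_count: int = 132) -> int:
    # sm_count=132 is an assumed value; the real code reads it from the device.
    return min(next_power_of_2(cdiv(m, 2 * sm_count)), 4)


for m in (7, 513, 8192):
    rows = calc_rows_per_block(m)
    print(f"M={m}: rows_per_block={rows}, grid dim 0 = {cdiv(m, rows)}")
# M=7: rows_per_block=1, grid dim 0 = 7
# M=513: rows_per_block=2, grid dim 0 = 257
# M=8192: rows_per_block=4, grid dim 0 = 2048

For small M the launch is unchanged, while for large M the first grid dimension shrinks by up to 4x, which is the speedup lever this commit targets.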