Compare commits

7 Commits

Author | SHA1 | Message | Date

Jason | ed11ac644a | Merge f10ff9c26237af8a96a7b3eff70d37d43609f7f4 into 9b4e9788e4a3a731f7567338ed15d3ec549ce03b | 2025-09-01 14:01:14 +00:00

GeeeekExplorer | 9b4e9788e4 | Merge pull request #969 from youkaichao/rmsnorm | 2025-08-28 11:24:26 +08:00
    act_quant_kernel

youkaichao | adecc0efbe | fix rmsnorm and act_quant_kernel | 2025-08-27 17:12:13 +08:00

youkaichao | 82f6008c8c | fix act_quant_kernel (#968) | 2025-08-27 16:23:30 +08:00
    Signed-off-by: youkaichao <youkaichao@gmail.com>

youkaichao | b15f0dbbbe | support scale_fmt=ue8m0 (#964) | 2025-08-27 15:30:21 +08:00
    * support scale_fmt=ue8m0
    * keep improving
    * keep improving
    * add clamp min of 1e-4
    * rename config
    Signed-off-by: youkaichao <youkaichao@gmail.com>

Xingkai Yu | 4592be48c0 | fp32 gate bias | 2025-08-26 17:39:07 +08:00

messagezsl | f10ff9c262 | Update kernel.py | 2025-02-08 16:07:47 +08:00
3 changed files with 70 additions and 14 deletions

View File

@@ -0,0 +1,23 @@
{
"vocab_size": 129280,
"dim": 7168,
"inter_dim": 18432,
"moe_inter_dim": 2048,
"n_layers": 61,
"n_dense_layers": 3,
"n_heads": 128,
"n_routed_experts": 256,
"n_shared_experts": 1,
"n_activated_experts": 8,
"n_expert_groups": 8,
"n_limited_groups": 4,
"route_scale": 2.5,
"score_func": "sigmoid",
"q_lora_rank": 1536,
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"dtype": "fp8",
"scale_fmt": "ue8m0"
}
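The file above is a new model configuration for the fp8 variant with the ue8m0 scale format; its "scale_fmt" key maps onto the ModelArgs field added later in this diff. A minimal loading sketch, assuming the JSON is saved under a hypothetical file name and that the repository's model.py is importable:

import json

from model import ModelArgs  # model.py as modified later in this diff

with open("config_671B_fp8_ue8m0.json") as f:  # hypothetical file name
    args = ModelArgs(**json.load(f))           # dataclass fields match the JSON keys one-to-one
print(args.dtype, args.scale_fmt)              # fp8 ue8m0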

View File

@@ -1,4 +1,4 @@
from typing import Tuple
from typing import Tuple, Optional
import torch
import triton
@@ -7,7 +7,7 @@ from triton import Config
@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, scale_fmt: tl.constexpr):
"""
Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
@@ -23,21 +23,26 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offs).to(tl.float32)
s = tl.max(tl.abs(x)) / 448.
amax = tl.max(tl.abs(x)) # reduction
amax = tl.maximum(amax, 1e-4) # clamp to 1e-4
s = amax / 448.
if scale_fmt == "ue8m0":
exp = tl.math.ceil(tl.math.log2(s))
s = tl.math.exp2(exp)
y = x / s
y = y.to(y_ptr.dtype.element_ty)
tl.store(y_ptr + offs, y)
tl.store(s_ptr + pid, s)
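For the "ue8m0" format the per-block scale is rounded up to the next power of two, so it can be stored as a bare 8-bit exponent (UE8M0: 8 exponent bits, no mantissa). Rounding up rather than to nearest keeps |x / s| <= 448, the largest normal value of float8_e4m3fn, so the fp8 cast in the kernel cannot overflow; for example, a block with amax = 30 gives s = 30 / 448 ≈ 0.067, which becomes 2**-3 = 0.125. A plain-PyTorch reference of the same logic (names here are ours, not the repository's), handy for unit-testing the kernel:

import torch

def act_quant_ref(x: torch.Tensor, block_size: int = 128, scale_fmt=None):
    # Same math as act_quant_kernel, written with plain tensor ops for testing.
    xb = x.float().view(-1, block_size)
    amax = xb.abs().amax(dim=-1, keepdim=True).clamp_min(1e-4)  # clamp matches the kernel
    s = amax / 448.0                                            # 448 = max normal of float8_e4m3fn
    if scale_fmt == "ue8m0":
        s = torch.exp2(torch.ceil(torch.log2(s)))               # round the scale up to a power of two
    y = (xb / s).to(torch.float8_e4m3fn).view_as(x)
    return y, s.view(*x.shape[:-1], -1)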
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
def act_quant(x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantizes the input tensor `x` using block-wise quantization.
Args:
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
scale_fmt (Optional[str], optional): The format of the scale. Default is None.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The quantized tensor with dtype `torch.float8_e4m3fn`.
@@ -48,7 +53,7 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size, scale_fmt=scale_fmt)
return y, s
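Because scale_fmt defaults to None, existing callers of act_quant keep the previous behaviour; passing "ue8m0" only changes how the returned scales are rounded. A quick usage check (a sketch; it needs a CUDA device since the kernel is Triton, and assumes kernel.py is importable as kernel):

import torch
from kernel import act_quant

x = torch.randn(4, 1024, device="cuda", dtype=torch.bfloat16)
y, s = act_quant(x, block_size=128, scale_fmt="ue8m0")
assert y.dtype == torch.float8_e4m3fn and s.shape == (4, 8)   # one scale per 128-wide block
assert torch.all(s == torch.exp2(torch.log2(s).round()))      # every scale is a power of two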
@@ -106,7 +111,7 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
fp8_gemm_configs = [
Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128, "GROUP_SIZE_M": 8}, num_stages=num_stages, num_warps=8)
for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
]
@@ -137,9 +142,32 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
Returns:
None
"""
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
# See above `L2 Cache Optimizations` section for details.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
k = tl.cdiv(K, BLOCK_SIZE_K)
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
# See above `Pointer Arithmetic` section for details
offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
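The rewritten index math, together with the GROUP_SIZE_M entry added to the autotune configs above, is the grouped launch order from the standard Triton matmul tutorial: a single 1D program id is decoded so that consecutive programs compute GROUP_SIZE_M vertically adjacent tiles of C before moving on to the next tile column, which keeps the tiles of A and B they share resident in L2. A host-side illustration of that mapping (not repository code):

from math import ceil

def grouped_order(M, N, BLOCK_M, BLOCK_N, GROUP_M=8):
    # Reproduces the pid -> (pid_m, pid_n) decoding used by fp8_gemm_kernel.
    num_pid_m, num_pid_n = ceil(M / BLOCK_M), ceil(N / BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n
    order = []
    for pid in range(num_pid_m * num_pid_n):
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
        order.append((pid_m, pid_n))
    return order

# grouped_order(256, 256, 64, 64, GROUP_M=2)[:4] == [(0, 0), (1, 0), (0, 1), (1, 1)]:
# two row-tiles are visited per column before the column index advances.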
@@ -186,6 +214,7 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Ten
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
# 1D launch kernel where each block gets its own program.
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
return c
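With that ordering in place, fp8_gemm now launches a 1D grid of cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N) programs instead of a 2D grid. An end-to-end sketch of the quantized path these kernels implement; the fp8 weight and its per-128x128-block scale below are dummies just to show the expected shapes, whereas a real model loads both from the checkpoint:

import torch
from kernel import act_quant, fp8_gemm   # kernel.py from this repository

M, K, N = 512, 7168, 2048
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)   # dummy fp8 weight, stored as (N, K)
b_s = torch.ones(N // 128, K // 128, device="cuda")            # dummy per-128x128-block weight scales

a_q, a_s = act_quant(a, block_size=128, scale_fmt="ue8m0")     # a_s has shape (M, K // 128)
c = fp8_gemm(a_q, a_s, b, b_s)                                 # output dtype follows torch.get_default_dtype()
print(c.shape)                                                 # torch.Size([512, 2048])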

View File

@@ -25,6 +25,7 @@ class ModelArgs:
max_batch_size (int): Maximum batch size.
max_seq_len (int): Maximum sequence length.
dtype (Literal["bf16", "fp8"]): Data type for computations.
scale_fmt (Optional[str]): Format for quantization scale.
vocab_size (int): Vocabulary size.
dim (int): Model dimension.
inter_dim (int): Intermediate dimension for MLP layers.
@@ -54,6 +55,7 @@ class ModelArgs:
max_batch_size: int = 8
max_seq_len: int = 4096 * 4
dtype: Literal["bf16", "fp8"] = "bf16"
scale_fmt: Optional[str] = None
vocab_size: int = 102400
dim: int = 2048
inter_dim: int = 10944
@@ -126,7 +128,7 @@ class ParallelEmbedding(nn.Module):
return y
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, scale_fmt: Optional[str] = None) -> torch.Tensor:
"""
Applies a linear transformation to the incoming data: y = xA^T + b.
This function supports specialized implementations based on quantization
@@ -154,7 +156,7 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] =
weight = weight_dequant(weight, weight.scale)
return F.linear(x, weight, bias)
else:
x, scale = act_quant(x, block_size)
x, scale = act_quant(x, block_size, scale_fmt)
y = fp8_gemm(x, scale, weight, weight.scale)
if bias is not None:
y += bias
@@ -172,6 +174,7 @@ class Linear(nn.Module):
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
dtype = torch.bfloat16
scale_fmt: Optional[str] = None
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
super().__init__()
@@ -199,7 +202,7 @@
Returns:
torch.Tensor: Transformed tensor after linear computation.
"""
return linear(x, self.weight, self.bias)
return linear(x, self.weight, self.bias, self.scale_fmt)
class ColumnParallelLinear(Linear):
@@ -558,7 +561,7 @@ class Gate(nn.Module):
self.score_func = args.score_func
self.route_scale = args.route_scale
self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None
self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32)) if self.dim == 7168 else None
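The gate's routing-bias correction, which only exists for the dim == 7168 configuration (the 671B-scale config added above), is now allocated in fp32 instead of the model dtype. The motivation is presumably precision: the bias is added to sigmoid routing scores, and a correction smaller than the bf16 spacing of a score would otherwise be rounded away. A toy illustration (not repository code):

import torch

score = torch.tensor(0.8125, dtype=torch.bfloat16)   # a sigmoid routing score
bias = torch.tensor(1e-3)                            # a small per-expert correction

print((score + bias.bfloat16()).item())   # 0.8125, the bf16 bias is rounded away
print((score + bias).item())              # ~0.8135, the fp32 bias survives the addition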
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -755,6 +758,7 @@ class Transformer(nn.Module):
world_size = dist.get_world_size() if dist.is_initialized() else 1
rank = dist.get_rank() if dist.is_initialized() else 0
Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
Linear.scale_fmt = args.scale_fmt
super().__init__()
self.max_seq_len = args.max_seq_len
self.embed = ParallelEmbedding(args.vocab_size, args.dim)
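Linear.dtype and the new Linear.scale_fmt are class attributes, so the two assignments above configure every Linear (including its column- and row-parallel subclasses) constructed afterwards, and linear() reads the format at call time. A small sketch of that flow, with an illustrative setup that mirrors those assignments:

import torch
from model import Linear, ModelArgs   # model.py from this repository

args = ModelArgs(dtype="fp8", scale_fmt="ue8m0")   # remaining fields keep their defaults
Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
Linear.scale_fmt = args.scale_fmt                  # the same assignments Transformer.__init__ makes

layer = Linear(7168, 2048)                         # fp8 weight plus a per-128x128-block scale parameter
print(layer.weight.dtype, layer.weight.scale.shape, layer.scale_fmt)
# torch.float8_e4m3fn torch.Size([16, 56]) ue8m0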