From 4592be48c07f036b32ef971474068aebc489e3e7 Mon Sep 17 00:00:00 2001 From: Xingkai Yu <38156925+GeeeekExplorer@users.noreply.github.com> Date: Tue, 26 Aug 2025 17:39:07 +0800 Subject: [PATCH 1/4] fp32 gate bias --- inference/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inference/model.py b/inference/model.py index c143e97..7539a68 100644 --- a/inference/model.py +++ b/inference/model.py @@ -558,7 +558,7 @@ class Gate(nn.Module): self.score_func = args.score_func self.route_scale = args.route_scale self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim)) - self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None + self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32)) if self.dim == 7168 else None def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ From b15f0dbbbe6a4bc403306175698439ef380f5fb5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 27 Aug 2025 15:30:21 +0800 Subject: [PATCH 2/4] support scale_fmt=ue8m0 (#964) * support scale_fmt=ue8m0 * keep improving Signed-off-by: youkaichao * keep improving Signed-off-by: youkaichao * add clamp min of 1e-4 Signed-off-by: youkaichao * rename config Signed-off-by: youkaichao --------- Signed-off-by: youkaichao --- inference/configs/config_v3.1.json | 23 +++++++++++++++++++++++ inference/kernel.py | 17 +++++++++++------ inference/model.py | 10 +++++++--- 3 files changed, 41 insertions(+), 9 deletions(-) create mode 100644 inference/configs/config_v3.1.json diff --git a/inference/configs/config_v3.1.json b/inference/configs/config_v3.1.json new file mode 100644 index 0000000..091d4cc --- /dev/null +++ b/inference/configs/config_v3.1.json @@ -0,0 +1,23 @@ +{ + "vocab_size": 129280, + "dim": 7168, + "inter_dim": 18432, + "moe_inter_dim": 2048, + "n_layers": 61, + "n_dense_layers": 3, + "n_heads": 128, + "n_routed_experts": 256, + "n_shared_experts": 1, + "n_activated_experts": 8, + "n_expert_groups": 8, + "n_limited_groups": 4, + "route_scale": 2.5, + "score_func": "sigmoid", + "q_lora_rank": 1536, + "kv_lora_rank": 512, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "v_head_dim": 128, + "dtype": "fp8", + "scale_fmt": "ue8m0" +} \ No newline at end of file diff --git a/inference/kernel.py b/inference/kernel.py index ba18dca..e12ad48 100644 --- a/inference/kernel.py +++ b/inference/kernel.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Tuple, Optional import torch import triton @@ -7,7 +7,7 @@ from triton import Config @triton.jit -def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): +def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, scale_fmt: tl.constexpr): """ Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`. @@ -23,21 +23,26 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x = tl.load(x_ptr + offs).to(tl.float32) - s = tl.max(tl.abs(x)) / 448. + amax = tl.max(tl.abs(x)) + amax = tl.min(amax, 1e-4) + s = amax / 448. + if scale_fmt == "ue8m0": + exp = tl.math.ceil(tl.math.log2(s)) + s = tl.math.exp2(exp) y = x / s y = y.to(y_ptr.dtype.element_ty) tl.store(y_ptr + offs, y) tl.store(s_ptr + pid, s) -def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: +def act_quant(x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantizes the input tensor `x` using block-wise quantization. Args: x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. - + scale_fmt (Optional[str], optional): The format of the scale. Default is None. Returns: Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - The quantized tensor with dtype `torch.float8_e4m3fn`. @@ -48,7 +53,7 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor y = torch.empty_like(x, dtype=torch.float8_e4m3fn) s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), ) - act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) + act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size, scale_fmt=scale_fmt) return y, s diff --git a/inference/model.py b/inference/model.py index 7539a68..8868499 100644 --- a/inference/model.py +++ b/inference/model.py @@ -25,6 +25,7 @@ class ModelArgs: max_batch_size (int): Maximum batch size. max_seq_len (int): Maximum sequence length. dtype (Literal["bf16", "fp8"]): Data type for computations. + scale_fmt (Optional[str]): Format for quantization scale. vocab_size (int): Vocabulary size. dim (int): Model dimension. inter_dim (int): Intermediate dimension for MLP layers. @@ -54,6 +55,7 @@ class ModelArgs: max_batch_size: int = 8 max_seq_len: int = 4096 * 4 dtype: Literal["bf16", "fp8"] = "bf16" + scale_fmt: Optional[str] = None vocab_size: int = 102400 dim: int = 2048 inter_dim: int = 10944 @@ -126,7 +128,7 @@ class ParallelEmbedding(nn.Module): return y -def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: +def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, scale_fmt: Optional[str] = None) -> torch.Tensor: """ Applies a linear transformation to the incoming data: y = xA^T + b. This function supports specialized implementations based on quantization @@ -154,7 +156,7 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = weight = weight_dequant(weight, weight.scale) return F.linear(x, weight, bias) else: - x, scale = act_quant(x, block_size) + x, scale = act_quant(x, block_size, scale_fmt) y = fp8_gemm(x, scale, weight, weight.scale) if bias is not None: y += bias @@ -172,6 +174,7 @@ class Linear(nn.Module): dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. """ dtype = torch.bfloat16 + scale_fmt: Optional[str] = None def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None): super().__init__() @@ -199,7 +202,7 @@ class Linear(nn.Module): Returns: torch.Tensor: Transformed tensor after linear computation. """ - return linear(x, self.weight, self.bias) + return linear(x, self.weight, self.bias, self.scale_fmt) class ColumnParallelLinear(Linear): @@ -755,6 +758,7 @@ class Transformer(nn.Module): world_size = dist.get_world_size() if dist.is_initialized() else 1 rank = dist.get_rank() if dist.is_initialized() else 0 Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16 + Linear.scale_fmt = args.scale_fmt super().__init__() self.max_seq_len = args.max_seq_len self.embed = ParallelEmbedding(args.vocab_size, args.dim) From 82f6008c8c6a69459d0d949cef75b8fc70096460 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 27 Aug 2025 16:23:30 +0800 Subject: [PATCH 3/4] fix act_quant_kernel (#968) Signed-off-by: youkaichao --- inference/kernel.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/inference/kernel.py b/inference/kernel.py index e12ad48..57832ad 100644 --- a/inference/kernel.py +++ b/inference/kernel.py @@ -23,8 +23,7 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, scale_fmt: t pid = tl.program_id(axis=0) offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x = tl.load(x_ptr + offs).to(tl.float32) - amax = tl.max(tl.abs(x)) - amax = tl.min(amax, 1e-4) + amax = tl.max(tl.abs(x), 1e-4) s = amax / 448. if scale_fmt == "ue8m0": exp = tl.math.ceil(tl.math.log2(s)) From adecc0efbe2fda18945734168fce6e0df0d804c3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 27 Aug 2025 17:12:13 +0800 Subject: [PATCH 4/4] fix rmsnorm and act_quant_kernel --- inference/kernel.py | 3 ++- inference/model.py | 7 ++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/inference/kernel.py b/inference/kernel.py index 57832ad..22afc92 100644 --- a/inference/kernel.py +++ b/inference/kernel.py @@ -23,7 +23,8 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, scale_fmt: t pid = tl.program_id(axis=0) offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x = tl.load(x_ptr + offs).to(tl.float32) - amax = tl.max(tl.abs(x), 1e-4) + amax = tl.max(tl.abs(x)) # reduction + amax = tl.maximum(amax, 1e-4) # clamp to 1e-4 s = amax / 448. if scale_fmt == "ue8m0": exp = tl.math.ceil(tl.math.log2(s)) diff --git a/inference/model.py b/inference/model.py index 8868499..9ce8168 100644 --- a/inference/model.py +++ b/inference/model.py @@ -291,7 +291,12 @@ class RMSNorm(nn.Module): Returns: torch.Tensor: Normalized tensor with the same shape as input. """ - return F.rms_norm(x, (self.dim,), self.weight, self.eps) + dtype = x.dtype + # make sure rms norm is computed in fp32 + x = x.to(torch.float32) + var = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(var + self.eps) + return (self.weight * x).to(dtype) def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor: