support scale_fmt=ue8m0

This commit is contained in:
youkaichao 2025-08-26 17:48:06 +08:00
parent f6e34dd267
commit 3745dc5ab6
3 changed files with 25 additions and 8 deletions

View File

@ -8,7 +8,7 @@ import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_model
from model import Transformer, ModelArgs
from model import Transformer, ModelArgs, set_global_args
def sample(logits, temperature: float = 1.0):
@ -110,7 +110,12 @@ def main(
torch.set_num_threads(8)
torch.manual_seed(965)
with open(config) as f:
args = ModelArgs(**json.load(f))
config_dict = json.load(f)
args = ModelArgs(**config_dict)
quantization_config = config_dict.get("quantization_config", None)
if quantization_config is not None:
args.scale_fmt = quantization_config.get("scale_fmt", None)
set_global_args(args)
print(args)
with torch.device("cuda"):
model = Transformer(args)

View File

@ -1,4 +1,4 @@
from typing import Tuple
from typing import Tuple, Optional
import torch
import triton
@ -7,7 +7,7 @@ from triton import Config
@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, scale_fmt: tl.constexpr):
"""
Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
@ -24,20 +24,23 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offs).to(tl.float32)
s = tl.max(tl.abs(x)) / 448.
if scale_fmt == "ue8m0":
exp = tl.math.ceil(tl.math.log2(s))
s = tl.math.exp2(exp)
y = x / s
y = y.to(y_ptr.dtype.element_ty)
tl.store(y_ptr + offs, y)
tl.store(s_ptr + pid, s)
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
def act_quant(x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantizes the input tensor `x` using block-wise quantization.
Args:
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
scale_fmt (Optional[str], optional): The format of the scale. Default is None.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The quantized tensor with dtype `torch.float8_e4m3fn`.
@ -48,7 +51,7 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size, scale_fmt=scale_fmt)
return y, s

View File

@ -25,6 +25,7 @@ class ModelArgs:
max_batch_size (int): Maximum batch size.
max_seq_len (int): Maximum sequence length.
dtype (Literal["bf16", "fp8"]): Data type for computations.
scale_fmt (Optional[str]): Format for quantization scale.
vocab_size (int): Vocabulary size.
dim (int): Model dimension.
inter_dim (int): Intermediate dimension for MLP layers.
@ -54,6 +55,7 @@ class ModelArgs:
max_batch_size: int = 8
max_seq_len: int = 4096 * 4
dtype: Literal["bf16", "fp8"] = "bf16"
scale_fmt: Optional[str] = None
vocab_size: int = 102400
dim: int = 2048
inter_dim: int = 10944
@ -83,6 +85,12 @@ class ModelArgs:
beta_slow: int = 1
mscale: float = 1.
global_args: Optional[ModelArgs] = None
def set_global_args(args: ModelArgs):
global global_args
global_args = args
class ParallelEmbedding(nn.Module):
"""
@ -154,7 +162,8 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] =
weight = weight_dequant(weight, weight.scale)
return F.linear(x, weight, bias)
else:
x, scale = act_quant(x, block_size)
assert global_args is not None, "global_args is required for fp8_gemm"
x, scale = act_quant(x, block_size, global_args.scale_fmt)
y = fp8_gemm(x, scale, weight, weight.scale)
if bias is not None:
y += bias