diff --git a/inference/generate.py b/inference/generate.py
index 7e9bffe..7a811e6 100644
--- a/inference/generate.py
+++ b/inference/generate.py
@@ -8,7 +8,7 @@ import torch.distributed as dist
 from transformers import AutoTokenizer
 from safetensors.torch import load_model
 
-from model import Transformer, ModelArgs
+from model import Transformer, ModelArgs, set_global_args
 
 
 def sample(logits, temperature: float = 1.0):
@@ -110,7 +110,12 @@ def main(
     torch.set_num_threads(8)
     torch.manual_seed(965)
     with open(config) as f:
-        args = ModelArgs(**json.load(f))
+        config_dict = json.load(f)
+        args = ModelArgs(**config_dict)
+        quantization_config = config_dict.get("quantization_config", None)
+        if quantization_config is not None:
+            args.scale_fmt = quantization_config.get("scale_fmt", None)
+        set_global_args(args)
     print(args)
     with torch.device("cuda"):
         model = Transformer(args)
diff --git a/inference/kernel.py b/inference/kernel.py
index ba18dca..af69daa 100644
--- a/inference/kernel.py
+++ b/inference/kernel.py
@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Tuple, Optional
 
 import torch
 import triton
@@ -7,7 +7,7 @@ from triton import Config
 
 
 @triton.jit
-def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
+def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, scale_fmt: tl.constexpr):
     """
     Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
 
@@ -24,20 +24,23 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
     offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     x = tl.load(x_ptr + offs).to(tl.float32)
     s = tl.max(tl.abs(x)) / 448.
+    if scale_fmt == "ue8m0":
+        exp = tl.math.ceil(tl.math.log2(s))
+        s = tl.math.exp2(exp)
     y = x / s
     y = y.to(y_ptr.dtype.element_ty)
     tl.store(y_ptr + offs, y)
     tl.store(s_ptr + pid, s)
 
 
-def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
+def act_quant(x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantizes the input tensor `x` using block-wise quantization.
 
     Args:
         x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
         block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
-
+        scale_fmt (Optional[str], optional): The format of the scale. Default is None.
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
             - The quantized tensor with dtype `torch.float8_e4m3fn`.
@@ -48,7 +51,7 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor
     y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
     s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
     grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
-    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
+    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size, scale_fmt=scale_fmt)
     return y, s
 
 
diff --git a/inference/model.py b/inference/model.py
index c143e97..2e0af14 100644
--- a/inference/model.py
+++ b/inference/model.py
@@ -25,6 +25,7 @@ class ModelArgs:
         max_batch_size (int): Maximum batch size.
         max_seq_len (int): Maximum sequence length.
         dtype (Literal["bf16", "fp8"]): Data type for computations.
+        scale_fmt (Optional[str]): Format for quantization scale.
         vocab_size (int): Vocabulary size.
         dim (int): Model dimension.
         inter_dim (int): Intermediate dimension for MLP layers.
@@ -54,6 +55,7 @@ class ModelArgs:
     max_batch_size: int = 8
     max_seq_len: int = 4096 * 4
     dtype: Literal["bf16", "fp8"] = "bf16"
+    scale_fmt: Optional[str] = None
     vocab_size: int = 102400
     dim: int = 2048
     inter_dim: int = 10944
@@ -83,6 +85,12 @@ class ModelArgs:
     beta_slow: int = 1
     mscale: float = 1.
 
+global_args: Optional[ModelArgs] = None
+
+def set_global_args(args: ModelArgs):
+    global global_args
+    global_args = args
+
 
 class ParallelEmbedding(nn.Module):
     """
@@ -154,7 +162,8 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] =
         weight = weight_dequant(weight, weight.scale)
         return F.linear(x, weight, bias)
     else:
-        x, scale = act_quant(x, block_size)
+        assert global_args is not None, "global_args is required for fp8_gemm"
+        x, scale = act_quant(x, block_size, global_args.scale_fmt)
         y = fp8_gemm(x, scale, weight, weight.scale)
         if bias is not None:
             y += bias
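
# Note (illustration only, not part of the patch): the new `scale_fmt == "ue8m0"` branch in
# `act_quant_kernel` rounds each per-block scale up to the nearest power of two before dividing,
# which makes the scale representable as an exponent-only (UE8M0) value. A rough PyTorch sketch
# of the same computation for a single 1-D block follows; `quantize_block_ue8m0` is a
# hypothetical helper, and 448 is the FP8 E4M3 maximum already used by the kernel above.
import torch

def quantize_block_ue8m0(block: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize one block to float8_e4m3fn with a power-of-two (UE8M0-style) scale."""
    s = block.abs().max().float() / 448.             # per-block scale, as in act_quant_kernel
    s = torch.exp2(torch.ceil(torch.log2(s)))        # round the scale up to the next power of two
    y = (block.float() / s).to(torch.float8_e4m3fn)  # quantize with the rounded scale
    return y, s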