diff --git a/inference/kernel.py b/inference/kernel.py
index 22afc92..c09d1d9 100644
--- a/inference/kernel.py
+++ b/inference/kernel.py
@@ -6,6 +6,38 @@ import triton.language as tl
 from triton import Config
 
 
+def convert_scale_to_ue8m0(scale_fp32: torch.Tensor) -> torch.Tensor:
+    """
+    Converts a float32 scale tensor to UE8M0 format (uint8 exponent).
+
+    Args:
+        scale_fp32 (torch.Tensor): Scale tensor in float32 format.
+
+    Returns:
+        torch.Tensor: Scale tensor in uint8 format (exponent + 127).
+    """
+    scale_fp32_clamped = torch.clamp(scale_fp32, min=1e-38)
+    exponent = torch.ceil(torch.log2(scale_fp32_clamped))
+    exponent_biased = (exponent + 127).to(torch.int32)
+    exponent_biased = torch.clamp(exponent_biased, 0, 255)
+    return exponent_biased.to(torch.uint8)
+
+
+def convert_scale_from_ue8m0(scale_uint8: torch.Tensor) -> torch.Tensor:
+    """
+    Converts a UE8M0 format scale tensor (uint8 exponent) to float32.
+
+    Args:
+        scale_uint8 (torch.Tensor): Scale tensor in uint8 format (exponent + 127).
+
+    Returns:
+        torch.Tensor: Scale tensor in float32 format (2^exponent).
+    """
+    exponent = scale_uint8.to(torch.int32) - 127
+    scale_fp32 = torch.pow(2.0, exponent.to(torch.float32))
+    return scale_fp32
+
+
 @triton.jit
 def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, scale_fmt: tl.constexpr):
     """
@@ -23,16 +55,19 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, scale_fmt: t
     pid = tl.program_id(axis=0)
     offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     x = tl.load(x_ptr + offs).to(tl.float32)
-    amax = tl.max(tl.abs(x))  # reduction
-    amax = tl.maximum(amax, 1e-4)  # clamp to 1e-4
+    amax = tl.max(tl.abs(x))
+    amax = tl.maximum(amax, 1e-4)
     s = amax / 448.
     if scale_fmt == "ue8m0":
         exp = tl.math.ceil(tl.math.log2(s))
         s = tl.math.exp2(exp)
+        exp_int = exp.to(tl.int32) + 127
+        tl.store(s_ptr + pid, exp_int.to(s_ptr.dtype.element_ty))
+    else:
+        tl.store(s_ptr + pid, s)
     y = x / s
     y = y.to(y_ptr.dtype.element_ty)
     tl.store(y_ptr + offs, y)
-    tl.store(s_ptr + pid, s)
 
 
 def act_quant(x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -51,14 +86,15 @@ def act_quant(x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] =
     assert x.is_contiguous(), 'Input tensor must be contiguous'
     assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'
     y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
-    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
+    scale_dtype = torch.uint8 if scale_fmt == "ue8m0" else torch.float32
+    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=scale_dtype)
     grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
     act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size, scale_fmt=scale_fmt)
     return y, s
 
 
 @triton.jit
-def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr, scale_fmt: tl.constexpr):
     """
     Dequantizes weights using the provided scaling factors and stores the result.
 
@@ -81,32 +117,39 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
     offs = offs_m[:, None] * N + offs_n[None, :]
     mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
     x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
-    s = tl.load(s_ptr + pid_m * n + pid_n)
+    if scale_fmt == "ue8m0":
+        s_uint8 = tl.load(s_ptr + pid_m * n + pid_n)
+        exp = s_uint8.to(tl.int32) - 127
+        s = tl.math.exp2(exp.to(tl.float32))
+    else:
+        s = tl.load(s_ptr + pid_m * n + pid_n)
     y = x * s
     tl.store(y_ptr + offs, y, mask=mask)
 
 
-def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
+def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None) -> torch.Tensor:
     """
-    Dequantizes the given weight tensor using the provided scale tensor.
+    Dequantizes the input tensor `x` using the provided scaling factors `s`.
 
     Args:
-        x (torch.Tensor): The quantized weight tensor of shape (M, N).
-        s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size).
-        block_size (int, optional): The block size to use for dequantization. Defaults to 128.
+        x (torch.Tensor): The quantized input tensor.
+        s (torch.Tensor): The scaling factors.
+        block_size (int, optional): The size of the blocks to be used for dequantization. Default is 128.
+        scale_fmt (Optional[str], optional): Scale format, "ue8m0" for uint8 exponent scales or None for float32. Default is None.
 
     Returns:
-        torch.Tensor: The dequantized weight tensor of the same shape as `x`.
-
-    Raises:
-        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
+        torch.Tensor: The dequantized tensor.
     """
     assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
     assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'
+    if scale_fmt is None:
+        scale_fmt = "ue8m0" if s.dtype == torch.uint8 else None
+    if scale_fmt == "ue8m0" and s.dtype != torch.uint8:
+        s = convert_scale_to_ue8m0(s)
     M, N = x.size()
     y = torch.empty_like(x, dtype=torch.get_default_dtype())
     grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
-    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
+    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size, scale_fmt=scale_fmt)
     return y
 
 
@@ -122,22 +165,26 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
                     a_s_ptr, b_s_ptr,
                     M, N: tl.constexpr, K: tl.constexpr,
                     BLOCK_SIZE_M: tl.constexpr,
                     BLOCK_SIZE_N: tl.constexpr,
-                    BLOCK_SIZE_K: tl.constexpr):
+                    BLOCK_SIZE_K: tl.constexpr,
+                    scale_fmt: tl.constexpr):
     """
     Performs a matrix multiplication operation on FP8 matrices with scaling factors.
 
     Args:
-        a_ptr (tl.tensor): Pointer to the first input matrix A.
-        b_ptr (tl.tensor): Pointer to the second input matrix B.
+        a_ptr (tl.tensor): Pointer to the first input matrix A (FP8).
+        b_ptr (tl.tensor): Pointer to the second input matrix B (FP8).
         c_ptr (tl.tensor): Pointer to the output matrix C.
         a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
+            (uint8 if scale_fmt=="ue8m0", float32 otherwise)
         b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
+            (uint8 if scale_fmt=="ue8m0", float32 otherwise)
         M (int): Number of rows in matrix A and C.
         N (tl.constexpr): Number of columns in matrix B and C.
         K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
         BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
         BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
         BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.
+        scale_fmt (tl.constexpr): Scale format ("ue8m0" for uint8 exponent format, None for float32).
 
     Returns:
         None
@@ -157,9 +204,20 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
     for i in range(k):
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
-        a_s = tl.load(a_s_ptrs)
-        b_s = tl.load(b_s_ptrs)
-        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+
+        if scale_fmt == "ue8m0":
+            a_s_uint8 = tl.load(a_s_ptrs)
+            b_s_uint8 = tl.load(b_s_ptrs)
+            a_exp = a_s_uint8.to(tl.int32) - 127
+            b_exp = b_s_uint8.to(tl.int32) - 127
+            combined_exp = a_exp[:, None] + b_exp[None, :]
+            scale = tl.math.exp2(combined_exp.to(tl.float32))
+            accumulator += tl.dot(a, b) * scale
+        else:
+            a_s = tl.load(a_s_ptrs)
+            b_s = tl.load(b_s_ptrs)
+            accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+
         a_ptrs += BLOCK_SIZE_K
         b_ptrs += BLOCK_SIZE_K
         a_s_ptrs += 1
@@ -172,7 +230,7 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
     tl.store(c_ptrs, c, mask=mask)
 
 
-def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
+def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor, scale_fmt: Optional[str] = None):
     """
     Perform a matrix multiplication using FP8 precision.
 
@@ -181,16 +239,28 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Ten
         a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
         b (torch.Tensor): The second input matrix, must be contiguous.
         b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
+        scale_fmt (Optional[str], optional): Scale format, "ue8m0" for uint8 exponent scales or None for float32. Default is None.
 
     Returns:
         torch.Tensor: The result of the matrix multiplication.
     """
     assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
     assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
+    if scale_fmt is None:
+        if a_s.dtype == torch.uint8 or b_s.dtype == torch.uint8:
+            scale_fmt = "ue8m0"
+        else:
+            scale_fmt = None
+    if scale_fmt == "ue8m0":
+        if a_s.dtype != torch.uint8:
+            a_s = convert_scale_to_ue8m0(a_s)
+        if b_s.dtype != torch.uint8:
+            b_s = convert_scale_to_ue8m0(b_s)
+
     K = a.size(-1)
     M = a.numel() // K
     N = b.size(0)
     c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
     grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
-    fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
+    fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K, scale_fmt=scale_fmt)
     return c
diff --git a/inference/model.py b/inference/model.py
index 8868499..70936ef 100644
--- a/inference/model.py
+++ b/inference/model.py
@@ -131,33 +131,24 @@ class ParallelEmbedding(nn.Module):
 def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, scale_fmt: Optional[str] = None) -> torch.Tensor:
     """
     Applies a linear transformation to the incoming data: y = xA^T + b.
-    This function supports specialized implementations based on quantization
-    and tensor formats.
 
     Args:
         x (torch.Tensor): The input tensor.
-        weight (torch.Tensor): The weight tensor. It may be quantized and
-            requires dequantization for certain cases.
+        weight (torch.Tensor): The weight tensor.
         bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.
+        scale_fmt (Optional[str], optional): Scale format, "ue8m0" for uint8 exponent scales or None for float32. Default is None.
 
     Returns:
-        torch.Tensor: The result of the linear transformation, which may involve
-            quantization-aware computations depending on the input parameters.
-
-    Notes:
-        - If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version
-          is used for computation.
-        - If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied.
-        - For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation.
+        torch.Tensor: The result of the linear transformation.
     """
     if weight.element_size() > 1:
         return F.linear(x, weight, bias)
     elif gemm_impl == "bf16":
-        weight = weight_dequant(weight, weight.scale)
+        weight = weight_dequant(weight, weight.scale, scale_fmt=scale_fmt)
         return F.linear(x, weight, bias)
     else:
         x, scale = act_quant(x, block_size, scale_fmt)
-        y = fp8_gemm(x, scale, weight, weight.scale)
+        y = fp8_gemm(x, scale, weight, weight.scale, scale_fmt)
         if bias is not None:
             y += bias
         return y
@@ -478,7 +469,7 @@ class MLA(nn.Module):
             self.v_cache[:bsz, start_pos:end_pos] = v
             scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
         else:
-            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
+            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size, Linear.scale_fmt)
             wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
             q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
             self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
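
Usage sketch (illustrative only; not applied by this patch). It shows the UE8M0 round trip through the new helpers and act_quant emitting uint8 exponent scales, so that fp8_gemm and weight_dequant can infer the format from the dtype. It assumes a CUDA device with float8_e4m3fn support and that the snippet is run from the inference/ directory so that `from kernel import ...` resolves; shapes and values are arbitrary.

import torch
from kernel import act_quant, convert_scale_to_ue8m0, convert_scale_from_ue8m0

# Encode/decode round trip (CPU is fine for the helpers): the decoded scale is the
# smallest power of two >= the original, since the exponent is ceil(log2(s)) biased by 127.
s = torch.tensor([0.0031, 0.25, 1.0, 3.7], dtype=torch.float32)
s_u8 = convert_scale_to_ue8m0(s)         # biased exponents, dtype torch.uint8
s_back = convert_scale_from_ue8m0(s_u8)  # powers of two, dtype torch.float32
assert torch.all(s_back >= s) and torch.all(s_back < 2 * s)

# Activation quantization with UE8M0 scales: act_quant now returns uint8 scales when
# scale_fmt="ue8m0", which is what the format auto-detection in fp8_gemm / weight_dequant keys on.
x = torch.randn(4, 1024, device="cuda", dtype=torch.bfloat16)
x_fp8, x_s = act_quant(x, block_size=128, scale_fmt="ue8m0")
print(x_fp8.dtype, x_s.dtype, x_s.shape)  # torch.float8_e4m3fn torch.uint8 torch.Size([4, 8])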