Mirror of https://git.datalinker.icu/deepseek-ai/DeepSeek-V3.git (synced 2025-12-09 04:44:28 +08:00)

feat: implement UE8M0 scale format support for FP8 inference

commit 73fe98d4b1 (parent 9b4e9788e4)
@@ -6,6 +6,38 @@ import triton.language as tl
from triton import Config


def convert_scale_to_ue8m0(scale_fp32: torch.Tensor) -> torch.Tensor:
    """
    Converts a float32 scale tensor to UE8M0 format (uint8 exponent).

    Args:
        scale_fp32 (torch.Tensor): Scale tensor in float32 format.

    Returns:
        torch.Tensor: Scale tensor in uint8 format (exponent + 127).
    """
    scale_fp32_clamped = torch.clamp(scale_fp32, min=1e-38)
    exponent = torch.ceil(torch.log2(scale_fp32_clamped))
    exponent_biased = (exponent + 127).to(torch.int32)
    exponent_biased = torch.clamp(exponent_biased, 0, 255)
    return exponent_biased.to(torch.uint8)


def convert_scale_from_ue8m0(scale_uint8: torch.Tensor) -> torch.Tensor:
    """
    Converts a UE8M0 format scale tensor (uint8 exponent) to float32.

    Args:
        scale_uint8 (torch.Tensor): Scale tensor in uint8 format (exponent + 127).

    Returns:
        torch.Tensor: Scale tensor in float32 format (2^exponent).
    """
    exponent = scale_uint8.to(torch.int32) - 127
    scale_fp32 = torch.pow(2.0, exponent.to(torch.float32))
    return scale_fp32
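(Illustrative, not part of the diff.) A quick sketch of how the two helpers round-trip a block scale; the import path is an assumption (the functions appear to be added to the same module as act_quant, e.g. inference/kernel.py):

import torch

from kernel import convert_scale_to_ue8m0, convert_scale_from_ue8m0  # assumed import path

scales = torch.tensor([1e-3, 0.25, 1.0, 3.7], dtype=torch.float32)

encoded = convert_scale_to_ue8m0(scales)      # uint8 biased exponents, e.g. 127 for 1.0
decoded = convert_scale_from_ue8m0(encoded)   # power-of-two float32 scales

# UE8M0 keeps only a rounded-up exponent, so the decoded scale is the smallest
# power of two that is >= the original scale (and less than twice it).
assert torch.all(decoded >= scales)
assert torch.all(decoded < 2 * scales)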
@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, scale_fmt: tl.constexpr):
    """
@@ -23,16 +55,19 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, scale_fmt: t
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offs).to(tl.float32)
    amax = tl.max(tl.abs(x)) # reduction
    amax = tl.maximum(amax, 1e-4) # clamp to 1e-4
    amax = tl.max(tl.abs(x))
    amax = tl.maximum(amax, 1e-4)
    s = amax / 448.
    if scale_fmt == "ue8m0":
        exp = tl.math.ceil(tl.math.log2(s))
        s = tl.math.exp2(exp)
        exp_int = exp.to(tl.int32) + 127
        tl.store(s_ptr + pid, exp_int.to(s_ptr.dtype.element_ty))
    else:
        tl.store(s_ptr + pid, s)
    y = x / s
    y = y.to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + offs, y)
    tl.store(s_ptr + pid, s)


def act_quant(x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -51,14 +86,15 @@ def act_quant(x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] =
    assert x.is_contiguous(), 'Input tensor must be contiguous'
    assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'
    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
    scale_dtype = torch.uint8 if scale_fmt == "ue8m0" else torch.float32
    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=scale_dtype)
    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size, scale_fmt=scale_fmt)
    return y, s
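(Illustrative, not part of the diff.) With the changes above, act_quant returns uint8 scales when scale_fmt="ue8m0" and float32 scales otherwise. A hedged usage sketch, assuming a CUDA device with FP8 support and the same assumed import path as above:

import torch

from kernel import act_quant, convert_scale_from_ue8m0  # assumed import path

x = torch.randn(4, 1024, dtype=torch.bfloat16, device="cuda")

# Default path: one float32 scale per 128-element block.
y_fp8, s_fp32 = act_quant(x, block_size=128)
print(y_fp8.dtype, s_fp32.dtype, s_fp32.shape)  # float8_e4m3fn, float32, (4, 8)

# UE8M0 path: the kernel stores biased exponents directly as uint8.
y_fp8, s_u8 = act_quant(x, block_size=128, scale_fmt="ue8m0")
print(s_u8.dtype, s_u8.shape)                   # uint8, (4, 8)

# Recover float32 power-of-two scales and roughly reconstruct the input.
s_dec = convert_scale_from_ue8m0(s_u8)
x_approx = y_fp8.to(torch.float32).view(4, 8, 128) * s_dec[..., None]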
@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr, scale_fmt: tl.constexpr):
    """
    Dequantizes weights using the provided scaling factors and stores the result.
@@ -81,32 +117,39 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
    s = tl.load(s_ptr + pid_m * n + pid_n)
    if scale_fmt == "ue8m0":
        s_uint8 = tl.load(s_ptr + pid_m * n + pid_n)
        exp = s_uint8.to(tl.int32) - 127
        s = tl.math.exp2(exp.to(tl.float32))
    else:
        s = tl.load(s_ptr + pid_m * n + pid_n)
    y = x * s
    tl.store(y_ptr + offs, y, mask=mask)


def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None) -> torch.Tensor:
    """
    Dequantizes the given weight tensor using the provided scale tensor.
    Dequantizes the input tensor `x` using the provided scaling factors `s`.

    Args:
        x (torch.Tensor): The quantized weight tensor of shape (M, N).
        s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size).
        block_size (int, optional): The block size to use for dequantization. Defaults to 128.
        x (torch.Tensor): The quantized input tensor.
        s (torch.Tensor): The scaling factors.
        block_size (int, optional): The size of the blocks to be used for dequantization. Default is 128.
        scale_fmt (Optional[str], optional): The format of the scale. Default is None.

    Returns:
        torch.Tensor: The dequantized weight tensor of the same shape as `x`.

    Raises:
        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
        torch.Tensor: The dequantized tensor.
    """
    assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
    assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'
    if scale_fmt is None:
        scale_fmt = "ue8m0" if s.dtype == torch.uint8 else None
    if scale_fmt == "ue8m0" and s.dtype != torch.uint8:
        s = convert_scale_to_ue8m0(s)
    M, N = x.size()
    y = torch.empty_like(x, dtype=torch.get_default_dtype())
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size, scale_fmt=scale_fmt)
    return y
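(Illustrative, not part of the diff.) The updated weight_dequant infers scale_fmt from s.dtype, so passing uint8 scales selects the UE8M0 path without an explicit argument. A sketch under the same import-path assumption; shapes and data are made up:

import torch

from kernel import weight_dequant, convert_scale_to_ue8m0  # assumed import path

M = N = 512
block_size = 128

w_fp8 = torch.randn(M, N, device="cuda").to(torch.float8_e4m3fn)
s_fp32 = torch.rand(M // block_size, N // block_size, dtype=torch.float32, device="cuda") + 0.5

w_a = weight_dequant(w_fp8, s_fp32, block_size)   # float32-scale path
s_u8 = convert_scale_to_ue8m0(s_fp32)
w_b = weight_dequant(w_fp8, s_u8, block_size)     # ue8m0 inferred from the uint8 dtype

# The results differ only because UE8M0 rounds each block scale up to a power
# of two, so |w_b| >= |w_a| elementwise, by less than a factor of 2.
print(w_a.dtype, w_b.dtype)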
@@ -122,22 +165,26 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
                    M, N: tl.constexpr, K: tl.constexpr,
                    BLOCK_SIZE_M: tl.constexpr,
                    BLOCK_SIZE_N: tl.constexpr,
                    BLOCK_SIZE_K: tl.constexpr):
                    BLOCK_SIZE_K: tl.constexpr,
                    scale_fmt: tl.constexpr):
    """
    Performs a matrix multiplication operation on FP8 matrices with scaling factors.

    Args:
        a_ptr (tl.tensor): Pointer to the first input matrix A.
        b_ptr (tl.tensor): Pointer to the second input matrix B.
        a_ptr (tl.tensor): Pointer to the first input matrix A (FP8).
        b_ptr (tl.tensor): Pointer to the second input matrix B (FP8).
        c_ptr (tl.tensor): Pointer to the output matrix C.
        a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
            (uint8 if scale_fmt=="ue8m0", float32 otherwise)
        b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
            (uint8 if scale_fmt=="ue8m0", float32 otherwise)
        M (int): Number of rows in matrix A and C.
        N (tl.constexpr): Number of columns in matrix B and C.
        K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
        BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
        BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
        BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.
        scale_fmt (tl.constexpr): Scale format ("ue8m0" for uint8 exponent format, None for float32).

    Returns:
        None
@@ -157,9 +204,20 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
    for i in range(k):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
        a_s = tl.load(a_s_ptrs)
        b_s = tl.load(b_s_ptrs)
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]

        if scale_fmt == "ue8m0":
            a_s_uint8 = tl.load(a_s_ptrs)
            b_s_uint8 = tl.load(b_s_ptrs)
            a_exp = a_s_uint8.to(tl.int32) - 127
            b_exp = b_s_uint8.to(tl.int32) - 127
            combined_exp = a_exp[:, None] + b_exp[None, :]
            scale = tl.math.exp2(combined_exp.to(tl.float32))
            accumulator += tl.dot(a, b) * scale
        else:
            a_s = tl.load(a_s_ptrs)
            b_s = tl.load(b_s_ptrs)
            accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]

        a_ptrs += BLOCK_SIZE_K
        b_ptrs += BLOCK_SIZE_K
        a_s_ptrs += 1
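(Illustrative, not part of the diff.) The UE8M0 branch above leans on the identity 2^ea * 2^eb = 2^(ea+eb): the per-tile scale outer product is formed with one integer add and one exp2 instead of multiplying two float scale vectors. The same arithmetic in plain PyTorch, with made-up exponent values:

import torch

a_s_uint8 = torch.tensor([126, 127, 129], dtype=torch.uint8)  # scales 0.5, 1.0, 4.0
b_s_uint8 = torch.tensor([125, 128], dtype=torch.uint8)       # scales 0.25, 2.0

a_exp = a_s_uint8.to(torch.int32) - 127
b_exp = b_s_uint8.to(torch.int32) - 127

# Outer sum of exponents == outer product of the power-of-two scales.
combined = torch.exp2((a_exp[:, None] + b_exp[None, :]).to(torch.float32))
reference = torch.exp2(a_exp.to(torch.float32))[:, None] * torch.exp2(b_exp.to(torch.float32))[None, :]
assert torch.equal(combined, reference)
print(combined)  # [[0.125, 1.0], [0.25, 2.0], [1.0, 8.0]]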
@@ -172,7 +230,7 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
    tl.store(c_ptrs, c, mask=mask)


def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor, scale_fmt: Optional[str] = None):
    """
    Perform a matrix multiplication using FP8 precision.

@@ -181,16 +239,28 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Ten
        a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
        b (torch.Tensor): The second input matrix, must be contiguous.
        b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
        scale_fmt (Optional[str], optional): The format of the scale. Default is None.

    Returns:
        torch.Tensor: The result of the matrix multiplication.
    """
    assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
    assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
    if scale_fmt is None:
        if a_s.dtype == torch.uint8 or b_s.dtype == torch.uint8:
            scale_fmt = "ue8m0"
        else:
            scale_fmt = None
    if scale_fmt == "ue8m0":
        if a_s.dtype != torch.uint8:
            a_s = convert_scale_to_ue8m0(a_s)
        if b_s.dtype != torch.uint8:
            b_s = convert_scale_to_ue8m0(b_s)

    K = a.size(-1)
    M = a.numel() // K
    N = b.size(0)
    c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
    fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
    fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K, scale_fmt=scale_fmt)
    return c
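(Illustrative, not part of the diff.) End to end, UE8M0 activation scales from act_quant can be paired with UE8M0-converted weight scales in fp8_gemm; scale_fmt is inferred from the uint8 dtypes but can also be passed explicitly. A sketch with assumed import path, shapes, and an FP8-capable GPU:

import torch

from kernel import act_quant, fp8_gemm, convert_scale_to_ue8m0  # assumed import path

torch.set_default_dtype(torch.bfloat16)
block_size = 128
M, K, N = 8, 1024, 2048

x = torch.randn(M, K, device="cuda")
w = torch.randn(N, K, device="cuda")

x_fp8, x_s = act_quant(x, block_size, scale_fmt="ue8m0")   # uint8 activation scales

# Pretend the weight was quantized offline with unit block scales, then
# convert those scales to UE8M0 so both operands share the format.
w_fp8 = w.to(torch.float8_e4m3fn)
w_s_fp32 = torch.ones(N // block_size, K // block_size, dtype=torch.float32, device="cuda")
w_s = convert_scale_to_ue8m0(w_s_fp32)

y = fp8_gemm(x_fp8, x_s, w_fp8, w_s, scale_fmt="ue8m0")
print(y.shape, y.dtype)  # torch.Size([8, 2048]), torch.bfloat16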
@@ -131,33 +131,24 @@ class ParallelEmbedding(nn.Module):
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, scale_fmt: Optional[str] = None) -> torch.Tensor:
    """
    Applies a linear transformation to the incoming data: y = xA^T + b.
    This function supports specialized implementations based on quantization
    and tensor formats.

    Args:
        x (torch.Tensor): The input tensor.
        weight (torch.Tensor): The weight tensor. It may be quantized and
            requires dequantization for certain cases.
        weight (torch.Tensor): The weight tensor.
        bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.
        scale_fmt (Optional[str], optional): The format of the scale. Default is None.

    Returns:
        torch.Tensor: The result of the linear transformation, which may involve
            quantization-aware computations depending on the input parameters.

    Notes:
        - If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version
          is used for computation.
        - If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied.
        - For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation.
        torch.Tensor: The result of the linear transformation.
    """
    if weight.element_size() > 1:
        return F.linear(x, weight, bias)
    elif gemm_impl == "bf16":
        weight = weight_dequant(weight, weight.scale)
        weight = weight_dequant(weight, weight.scale, scale_fmt=scale_fmt)
        return F.linear(x, weight, bias)
    else:
        x, scale = act_quant(x, block_size, scale_fmt)
        y = fp8_gemm(x, scale, weight, weight.scale)
        y = fp8_gemm(x, scale, weight, weight.scale, scale_fmt)
        if bias is not None:
            y += bias
        return y
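(Illustrative, not part of the diff.) A hypothetical stand-alone rewrite of the dispatch in linear, with gemm_impl, block_size, and the scale format (mirroring the Linear.scale_fmt attribute used in the MLA hunk below) treated as assumed module-level settings:

import torch
import torch.nn.functional as F

from kernel import act_quant, fp8_gemm, weight_dequant  # assumed import path

gemm_impl = "fp8"     # assumed global, as in model.py ("bf16" or "fp8")
block_size = 128      # assumed global, as in model.py
scale_fmt = "ue8m0"   # assumed to come from something like Linear.scale_fmt

def linear_sketch(x, weight, weight_scale, bias=None):
    """Hypothetical stand-alone version of the branches in `linear`."""
    if weight.element_size() > 1:          # weight is not FP8-quantized
        return F.linear(x, weight, bias)
    if gemm_impl == "bf16":                # dequantize once, run a bf16 GEMM
        w = weight_dequant(weight, weight_scale, block_size, scale_fmt=scale_fmt)
        return F.linear(x, w, bias)
    x_fp8, x_s = act_quant(x, block_size, scale_fmt)   # quantize activations on the fly
    y = fp8_gemm(x_fp8, x_s, weight, weight_scale, scale_fmt)
    return y if bias is None else y + bias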
@@ -478,7 +469,7 @@ class MLA(nn.Module):
            self.v_cache[:bsz, start_pos:end_pos] = v
            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
        else:
            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size, Linear.scale_fmt)
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)