mirror of https://git.datalinker.icu/deepseek-ai/DeepSeek-V3.git
synced 2025-12-14 23:34:32 +08:00

commit 3745dc5ab6
parent f6e34dd267

    support scale_fmt=ue8m0
inference/generate.py
@@ -8,7 +8,7 @@ import torch.distributed as dist
 from transformers import AutoTokenizer
 from safetensors.torch import load_model
 
-from model import Transformer, ModelArgs
+from model import Transformer, ModelArgs, set_global_args
 
 
 def sample(logits, temperature: float = 1.0):
@@ -110,7 +110,12 @@ def main(
     torch.set_num_threads(8)
     torch.manual_seed(965)
     with open(config) as f:
-        args = ModelArgs(**json.load(f))
+        config_dict = json.load(f)
+    args = ModelArgs(**config_dict)
+    quantization_config = config_dict.get("quantization_config", None)
+    if quantization_config is not None:
+        args.scale_fmt = quantization_config.get("scale_fmt", None)
+    set_global_args(args)
     print(args)
     with torch.device("cuda"):
         model = Transformer(args)
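For context: the loader above now reads an optional "quantization_config" block
from the model's config JSON. A minimal sketch of the shape it expects; the
surrounding keys are illustrative placeholders, not the full DeepSeek-V3 config:

    import json

    # Hypothetical, trimmed config.json; only "quantization_config.scale_fmt"
    # is what the new loader code above actually reads.
    config_text = '{"dim": 2048, "quantization_config": {"scale_fmt": "ue8m0"}}'
    config_dict = json.loads(config_text)

    quantization_config = config_dict.get("quantization_config", None)
    scale_fmt = None
    if quantization_config is not None:
        scale_fmt = quantization_config.get("scale_fmt", None)
    print(scale_fmt)  # -> ue8m0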
inference/kernel.py
@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Tuple, Optional
 
 import torch
 import triton
@@ -7,7 +7,7 @@ from triton import Config
 
 
 @triton.jit
-def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
+def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, scale_fmt: tl.constexpr):
     """
     Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
 
@@ -24,20 +24,23 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
     offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     x = tl.load(x_ptr + offs).to(tl.float32)
     s = tl.max(tl.abs(x)) / 448.
+    if scale_fmt == "ue8m0":
+        exp = tl.math.ceil(tl.math.log2(s))
+        s = tl.math.exp2(exp)
     y = x / s
     y = y.to(y_ptr.dtype.element_ty)
     tl.store(y_ptr + offs, y)
     tl.store(s_ptr + pid, s)
 
 
-def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
+def act_quant(x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantizes the input tensor `x` using block-wise quantization.
 
     Args:
         x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
         block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
+        scale_fmt (Optional[str], optional): The format of the scale. Default is None.
 
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
             - The quantized tensor with dtype `torch.float8_e4m3fn`.
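Why the new branch rounds the way it does: a ue8m0 scale is an unsigned FP8
value with 8 exponent bits and no mantissa bits, so it can only represent exact
powers of two. The kernel therefore rounds each block's amax-derived scale up
(ceil of log2) rather than to nearest, so the quantized values never exceed the
fp8 range. A plain-PyTorch restatement of that branch, useful as a reference
check against the Triton kernel (a sketch, not part of the commit):

    import torch

    x = torch.randn(128)
    s = x.abs().max() / 448.0  # 448 is the max normal value of float8_e4m3fn
    s_ue8m0 = torch.exp2(torch.ceil(torch.log2(s)))  # round up to a power of two
    assert s_ue8m0 >= s  # rounding up keeps x / s_ue8m0 within the fp8 range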
@@ -48,7 +51,7 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
     y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
     s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
     grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
-    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
+    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size, scale_fmt=scale_fmt)
     return y, s
 
 
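A hypothetical call site for the extended act_quant (assumes a CUDA device with
Triton and float8 support; not part of the commit):

    x = torch.randn(4, 4096, dtype=torch.bfloat16, device="cuda")
    y, s = act_quant(x, block_size=128, scale_fmt="ue8m0")
    # y: float8_e4m3fn values with the same shape as x
    # s: float32 scales of shape (4, 32), each an exact power of two under ue8m0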
inference/model.py
@@ -25,6 +25,7 @@ class ModelArgs:
         max_batch_size (int): Maximum batch size.
         max_seq_len (int): Maximum sequence length.
         dtype (Literal["bf16", "fp8"]): Data type for computations.
+        scale_fmt (Optional[str]): Format for quantization scale.
         vocab_size (int): Vocabulary size.
         dim (int): Model dimension.
         inter_dim (int): Intermediate dimension for MLP layers.
@@ -54,6 +55,7 @@ class ModelArgs:
     max_batch_size: int = 8
     max_seq_len: int = 4096 * 4
     dtype: Literal["bf16", "fp8"] = "bf16"
+    scale_fmt: Optional[str] = None
     vocab_size: int = 102400
     dim: int = 2048
     inter_dim: int = 10944
@@ -83,6 +85,12 @@ class ModelArgs:
     beta_slow: int = 1
     mscale: float = 1.
 
+global_args: Optional[ModelArgs] = None
+
+def set_global_args(args: ModelArgs):
+    global global_args
+    global_args = args
+
 
 class ParallelEmbedding(nn.Module):
     """
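Design note: linear() below has no handle on the model's ModelArgs, so the
commit routes scale_fmt through a module-level global that generate.py sets
once at startup. How the pieces connect, in a minimal sketch mirroring the
generate.py hunk above:

    args = ModelArgs(scale_fmt="ue8m0")
    set_global_args(args)  # after this, linear() can read global_args.scale_fmt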
@@ -154,7 +162,8 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         weight = weight_dequant(weight, weight.scale)
         return F.linear(x, weight, bias)
     else:
-        x, scale = act_quant(x, block_size)
+        assert global_args is not None, "global_args is required for fp8_gemm"
+        x, scale = act_quant(x, block_size, global_args.scale_fmt)
         y = fp8_gemm(x, scale, weight, weight.scale)
         if bias is not None:
             y += bias