Mirror of https://git.datalinker.icu/deepseek-ai/DeepSeek-V3.git, synced 2025-12-09 04:44:28 +08:00
Compare commits: c4e3479282 ... b392f0f010 (10 commits)

| SHA1 |
|---|
| b392f0f010 |
| 9b4e9788e4 |
| adecc0efbe |
| 82f6008c8c |
| b15f0dbbbe |
| 4592be48c0 |
| d700e0056d |
| 361d0bcc1c |
| 35703ca641 |
| d3be6c9d91 |
Binary file not shown. (Before: 179 KiB, After: 100 KiB)

BIN figures/niah.png
Binary file not shown. (Before: 106 KiB, After: 78 KiB)
inference/configs/config_v3.1.json (new file, 23 lines)

@@ -0,0 +1,23 @@
+{
+    "vocab_size": 129280,
+    "dim": 7168,
+    "inter_dim": 18432,
+    "moe_inter_dim": 2048,
+    "n_layers": 61,
+    "n_dense_layers": 3,
+    "n_heads": 128,
+    "n_routed_experts": 256,
+    "n_shared_experts": 1,
+    "n_activated_experts": 8,
+    "n_expert_groups": 8,
+    "n_limited_groups": 4,
+    "route_scale": 2.5,
+    "score_func": "sigmoid",
+    "q_lora_rank": 1536,
+    "kv_lora_rank": 512,
+    "qk_nope_head_dim": 128,
+    "qk_rope_head_dim": 64,
+    "v_head_dim": 128,
+    "dtype": "fp8",
+    "scale_fmt": "ue8m0"
+}
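A minimal sketch of how such a config is consumed, assuming the JSON fields map one-to-one onto the `ModelArgs` dataclass from inference/model.py (the generate.py-style loading pattern; the exact repo code path may differ, and the import assumes inference/ is on sys.path):

```python
import json

from model import ModelArgs  # inference/model.py; assumes inference/ is on sys.path

with open("inference/configs/config_v3.1.json") as f:
    args = ModelArgs(**json.load(f))  # each JSON key is a ModelArgs field

print(args.dtype, args.scale_fmt)  # fp8 ue8m0
```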
inference/convert.py

@@ -3,7 +3,6 @@ import shutil
 from argparse import ArgumentParser
 from glob import glob
 from tqdm import tqdm, trange
 
 import torch
 from safetensors.torch import safe_open, save_file
-
@@ -30,7 +29,7 @@ mapping = {
 }
 
 
-def main(hf_ckpt_path, save_path, n_experts, mp):
+def main(hf_ckpt_path: str, save_path: str, n_experts: int, mp: int) -> None:
     """
     Converts and saves model checkpoint files into a specified format.
 
@@ -43,46 +42,50 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
     Returns:
         None
     """
-    torch.set_num_threads(8)
-    n_local_experts = n_experts // mp
-    state_dicts = [{} for _ in range(mp)]
+    try:
+        torch.set_num_threads(8)
+        n_local_experts = n_experts // mp
+        state_dicts = [{} for _ in range(mp)]
 
-    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
-        with safe_open(file_path, framework="pt", device="cpu") as f:
-            for name in f.keys():
-                if "model.layers.61" in name:
-                    continue
-                param: torch.Tensor = f.get_tensor(name)
-                if name.startswith("model."):
-                    name = name[len("model."):]
-                name = name.replace("self_attn", "attn")
-                name = name.replace("mlp", "ffn")
-                name = name.replace("weight_scale_inv", "scale")
-                name = name.replace("e_score_correction_bias", "bias")
-                key = name.split(".")[-2]
-                assert key in mapping, f"Key {key} not found in mapping"
-                new_key, dim = mapping[key]
-                name = name.replace(key, new_key)
-                for i in range(mp):
-                    new_param = param
-                    if "experts" in name and "shared_experts" not in name:
-                        idx = int(name.split(".")[-3])
-                        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
-                            continue
-                    elif dim is not None:
-                        assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
-                        shard_size = param.size(dim) // mp
-                        new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
-                    state_dicts[i][name] = new_param
+        for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
+            with safe_open(file_path, framework="pt", device="cpu") as f:
+                for name in f.keys():
+                    if "model.layers.61" in name:
+                        continue
+                    param: torch.Tensor = f.get_tensor(name)
+                    if name.startswith("model."):
+                        name = name[len("model."):]
+                    name = name.replace("self_attn", "attn")
+                    name = name.replace("mlp", "ffn")
+                    name = name.replace("weight_scale_inv", "scale")
+                    name = name.replace("e_score_correction_bias", "bias")
+                    key = name.split(".")[-2]
+                    assert key in mapping, f"Key {key} not found in mapping"
+                    new_key, dim = mapping[key]
+                    name = name.replace(key, new_key)
+                    for i in range(mp):
+                        new_param = param
+                        if "experts" in name and "shared_experts" not in name:
+                            idx = int(name.split(".")[-3])
+                            if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
+                                continue
+                        elif dim is not None:
+                            assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
+                            shard_size = param.size(dim) // mp
+                            new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
+                        state_dicts[i][name] = new_param
 
-    os.makedirs(save_path, exist_ok=True)
+        os.makedirs(save_path, exist_ok=True)
 
-    for i in trange(mp):
-        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
+        for i in trange(mp):
+            save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
 
-    for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
-        new_file_path = os.path.join(save_path, os.path.basename(file_path))
-        shutil.copyfile(file_path, new_file_path)
+        for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
+            new_file_path = os.path.join(save_path, os.path.basename(file_path))
+            shutil.copyfile(file_path, new_file_path)
+
+    except Exception as e:
+        print(f"An error occurred: {e}")
 
 
 if __name__ == "__main__":
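The mp-way sharding in main() splits each partitionable tensor into contiguous slices with narrow(). A toy, self-contained illustration of that slicing rule (names here are illustrative, not from the repo):

```python
import torch

mp, dim = 4, 0                            # 4-way model parallelism, shard along dim 0
param = torch.arange(16.).reshape(8, 2)   # stand-in weight tensor
assert param.size(dim) % mp == 0, "dimension must divide evenly across ranks"
shard_size = param.size(dim) // mp        # rows per rank
shards = [param.narrow(dim, i * shard_size, shard_size).contiguous() for i in range(mp)]
assert torch.equal(torch.cat(shards, dim=dim), param)  # the shards tile the original
```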
inference/kernel.py

@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Tuple, Optional
 
 import torch
 import triton
@@ -7,7 +7,7 @@ from triton import Config
 
 
 @triton.jit
-def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
+def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, scale_fmt: tl.constexpr):
     """
     Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
 
@@ -23,21 +23,26 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
     pid = tl.program_id(axis=0)
     offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     x = tl.load(x_ptr + offs).to(tl.float32)
-    s = tl.max(tl.abs(x)) / 448.
+    amax = tl.max(tl.abs(x))  # reduction
+    amax = tl.maximum(amax, 1e-4)  # clamp to 1e-4
+    s = amax / 448.
+    if scale_fmt == "ue8m0":
+        exp = tl.math.ceil(tl.math.log2(s))
+        s = tl.math.exp2(exp)
     y = x / s
     y = y.to(y_ptr.dtype.element_ty)
     tl.store(y_ptr + offs, y)
     tl.store(s_ptr + pid, s)
 
 
-def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
+def act_quant(x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantizes the input tensor `x` using block-wise quantization.
 
     Args:
         x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
         block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
-
+        scale_fmt (Optional[str], optional): The format of the scale. Default is None.
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
         - The quantized tensor with dtype `torch.float8_e4m3fn`.
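For readers without a Triton setup, here is a plain-PyTorch sketch of the per-block math the kernel performs on one BLOCK_SIZE-sized block (same 448 E4M3 maximum and the ue8m0 rule of rounding the scale up to a power of two; the function name is mine, not the repo's):

```python
from typing import Optional

import torch

def act_quant_block_ref(x: torch.Tensor, scale_fmt: Optional[str] = None):
    # Per-block scale: amax / 448, clamped away from zero as in the kernel.
    amax = x.abs().amax().float().clamp_min(1e-4)
    s = amax / 448.0  # 448 = largest finite fp8 e4m3 value
    if scale_fmt == "ue8m0":
        # ue8m0: keep only a power-of-two scale, rounded up so values still fit.
        s = torch.exp2(torch.ceil(torch.log2(s)))
    y = (x / s).to(torch.float8_e4m3fn)
    return y, s

y, s = act_quant_block_ref(torch.randn(128), scale_fmt="ue8m0")
```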
@@ -48,7 +53,7 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
     y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
     s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
     grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
-    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
+    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size, scale_fmt=scale_fmt)
     return y, s
 
 
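Expected call shape after this change, as a sketch (assumes a CUDA device for the Triton kernel, a contiguous input, and a last dimension divisible by block_size):

```python
import torch

from kernel import act_quant  # inference/kernel.py; assumes inference/ is on sys.path

x = torch.randn(4, 7168, device="cuda", dtype=torch.bfloat16)  # 7168 % 128 == 0
y, s = act_quant(x, block_size=128, scale_fmt="ue8m0")
# y: float8_e4m3fn, same shape as x; s: float32 scales, shape (4, 7168 // 128)
```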
inference/model.py

@@ -25,6 +25,7 @@ class ModelArgs:
         max_batch_size (int): Maximum batch size.
         max_seq_len (int): Maximum sequence length.
         dtype (Literal["bf16", "fp8"]): Data type for computations.
+        scale_fmt (Optional[str]): Format for quantization scale.
         vocab_size (int): Vocabulary size.
         dim (int): Model dimension.
         inter_dim (int): Intermediate dimension for MLP layers.
@@ -54,6 +55,7 @@ class ModelArgs:
     max_batch_size: int = 8
     max_seq_len: int = 4096 * 4
     dtype: Literal["bf16", "fp8"] = "bf16"
+    scale_fmt: Optional[str] = None
     vocab_size: int = 102400
     dim: int = 2048
     inter_dim: int = 10944
@@ -126,7 +128,7 @@ class ParallelEmbedding(nn.Module):
         return y
 
 
-def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, scale_fmt: Optional[str] = None) -> torch.Tensor:
     """
     Applies a linear transformation to the incoming data: y = xA^T + b.
     This function supports specialized implementations based on quantization
@@ -154,7 +156,7 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         weight = weight_dequant(weight, weight.scale)
         return F.linear(x, weight, bias)
     else:
-        x, scale = act_quant(x, block_size)
+        x, scale = act_quant(x, block_size, scale_fmt)
         y = fp8_gemm(x, scale, weight, weight.scale)
         if bias is not None:
             y += bias
@@ -172,6 +174,7 @@ class Linear(nn.Module):
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
     dtype = torch.bfloat16
+    scale_fmt: Optional[str] = None
 
     def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
         super().__init__()
@@ -199,7 +202,7 @@ class Linear(nn.Module):
         Returns:
             torch.Tensor: Transformed tensor after linear computation.
         """
-        return linear(x, self.weight, self.bias)
+        return linear(x, self.weight, self.bias, self.scale_fmt)
 
 
 class ColumnParallelLinear(Linear):
@@ -558,7 +561,7 @@ class Gate(nn.Module):
         self.score_func = args.score_func
         self.route_scale = args.route_scale
         self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
-        self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None
+        self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32)) if self.dim == 7168 else None
 
     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
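A plausible reading of this change (my inference, not stated in the commit): keeping the routing correction bias in float32 means the score-plus-bias comparison runs at full precision, so nearly-tied experts are not reordered by bf16 rounding. A toy sketch of that comparison, with illustrative names:

```python
import torch

n_experts, k = 256, 8
scores = torch.sigmoid(torch.randn(n_experts, dtype=torch.bfloat16))  # per-expert affinities
bias = torch.zeros(n_experts, dtype=torch.float32)                    # correction bias, kept in fp32
# Adding the bias in float32 keeps near-tied scores distinct; a bf16 add
# could round two close experts onto the same value and flip the top-k.
selected = (scores.float() + bias).topk(k).indices
```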
@@ -755,6 +758,7 @@ class Transformer(nn.Module):
         world_size = dist.get_world_size() if dist.is_initialized() else 1
         rank = dist.get_rank() if dist.is_initialized() else 0
         Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
+        Linear.scale_fmt = args.scale_fmt
         super().__init__()
         self.max_seq_len = args.max_seq_len
         self.embed = ParallelEmbedding(args.vocab_size, args.dim)
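The scale_fmt plumbing mirrors the existing Linear.dtype trick: Transformer.__init__ writes a class attribute once, and every Linear layer reads it in forward() through ordinary attribute lookup. A minimal, self-contained sketch with stand-in classes:

```python
class Linear:                        # stand-in, not the repo's nn.Module subclass
    scale_fmt = None                 # class-level default, like model.py's Linear.scale_fmt

    def fmt(self):
        return self.scale_fmt        # instance lookup falls back to the class attribute

Linear.scale_fmt = "ue8m0"           # what Transformer.__init__ effectively does
assert Linear().fmt() == "ue8m0"     # every instance, old or new, sees the update
```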