Mirror of https://git.datalinker.icu/deepseek-ai/DeepSeek-V3.git
commit 348e741a11
parent 3745dc5ab6

keep improving

Signed-off-by: youkaichao <youkaichao@gmail.com>
diff --git a/inference/configs/config_671B_v3.1.json b/inference/configs/config_671B_v3.1.json
new file (+23 lines)
@@ -0,0 +1,23 @@
+{
+    "vocab_size": 129280,
+    "dim": 7168,
+    "inter_dim": 18432,
+    "moe_inter_dim": 2048,
+    "n_layers": 61,
+    "n_dense_layers": 3,
+    "n_heads": 128,
+    "n_routed_experts": 256,
+    "n_shared_experts": 1,
+    "n_activated_experts": 8,
+    "n_expert_groups": 8,
+    "n_limited_groups": 4,
+    "route_scale": 2.5,
+    "score_func": "sigmoid",
+    "q_lora_rank": 1536,
+    "kv_lora_rank": 512,
+    "qk_nope_head_dim": 128,
+    "qk_rope_head_dim": 64,
+    "v_head_dim": 128,
+    "dtype": "fp8",
+    "scale_fmt": "ue8m0"
+}
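The last two keys are the point of this file: "dtype": "fp8" selects float8_e4m3fn weights, and "scale_fmt": "ue8m0" selects an unsigned, exponent-only 8-bit scale encoding (8 exponent bits, 0 mantissa bits), so every quantization scale is a power of two. The remaining keys describe the 671B architecture; a small sketch of the shapes they imply, assuming the file lives at the path above (this script is not part of the commit):

import json

with open("inference/configs/config_671B_v3.1.json") as f:
    cfg = json.load(f)

# Per-head query/key width in MLA is the no-RoPE part plus the RoPE part.
print(cfg["qk_nope_head_dim"] + cfg["qk_rope_head_dim"])        # 192
# Each token runs its shared expert(s) plus the top-k routed experts.
print(cfg["n_shared_experts"] + cfg["n_activated_experts"])     # 9
# Routing is restricted to a subset of the expert groups.
print(cfg["n_limited_groups"], "of", cfg["n_expert_groups"])    # 4 of 8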
diff --git a/inference/generate.py b/inference/generate.py
@@ -112,10 +112,6 @@ def main(
     with open(config) as f:
         config_dict = json.load(f)
     args = ModelArgs(**config_dict)
-    quantization_config = config_dict.get("quantization_config", None)
-    if quantization_config is not None:
-        args.scale_fmt = quantization_config.get("scale_fmt", None)
-    set_global_args(args)
     print(args)
     with torch.device("cuda"):
         model = Transformer(args)
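The deleted branch is obsolete because scale_fmt is now a top-level key in the config JSON above, so ModelArgs(**config_dict) populates it like any other field and no global registry is needed. A minimal sketch of that flow with a stand-in dataclass (the real ModelArgs has many more fields):

import json
from dataclasses import dataclass
from typing import Optional

@dataclass
class ArgsSketch:                  # illustrative stand-in for ModelArgs
    dtype: str = "bf16"
    scale_fmt: Optional[str] = None

config_dict = json.loads('{"dtype": "fp8", "scale_fmt": "ue8m0"}')
args = ArgsSketch(**config_dict)   # flat key, no quantization_config lookup
assert args.scale_fmt == "ue8m0"   # and no set_global_args(args) afterwards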
diff --git a/inference/model.py b/inference/model.py
@@ -85,12 +85,6 @@ class ModelArgs:
     beta_slow: int = 1
     mscale: float = 1.
 
-global_args: Optional[ModelArgs] = None
-
-def set_global_args(args: ModelArgs):
-    global global_args
-    global_args = args
-
 
 class ParallelEmbedding(nn.Module):
     """
@@ -134,7 +128,7 @@ class ParallelEmbedding(nn.Module):
         return y
 
 
-def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, scale_fmt: Optional[str] = None) -> torch.Tensor:
     """
     Applies a linear transformation to the incoming data: y = xA^T + b.
     This function supports specialized implementations based on quantization
@@ -162,8 +156,7 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] =
         weight = weight_dequant(weight, weight.scale)
         return F.linear(x, weight, bias)
     else:
-        assert global_args is not None, "global_args is required for fp8_gemm"
-        x, scale = act_quant(x, block_size, global_args.scale_fmt)
+        x, scale = act_quant(x, block_size, scale_fmt)
         y = fp8_gemm(x, scale, weight, weight.scale)
         if bias is not None:
             y += bias
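These two hunks thread scale_fmt from the caller into act_quant instead of reading the mutable global_args, which also makes the assert unnecessary: a missing value is simply the parameter's None default. A self-contained toy with the same call shape as act_quant (the real kernel quantizes per (1, block_size) tile on GPU; this sketch is illustrative only):

from typing import Optional
import torch

def act_quant_sketch(x: torch.Tensor, block_size: int,
                     scale_fmt: Optional[str] = None):
    # Blockwise absmax scaling into float8_e4m3fn (max finite value 448).
    blocks = x.view(-1, block_size)
    scale = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4) / 448.0
    if scale_fmt == "ue8m0":
        # UE8M0 scales carry only an exponent: round up to a power of two.
        scale = torch.exp2(torch.ceil(torch.log2(scale)))
    return (blocks / scale).to(torch.float8_e4m3fn).view_as(x), scale

q, s = act_quant_sketch(torch.randn(4, 256), 128, scale_fmt="ue8m0")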
@@ -181,6 +174,7 @@ class Linear(nn.Module):
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
     dtype = torch.bfloat16
+    scale_fmt: Optional[str] = None
 
     def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
         super().__init__()
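scale_fmt joins dtype as a Linear class attribute, so it is set once per process and seen by every instance, including subclasses such as the ColumnParallelLinear in the next hunk. A minimal illustration of the pattern (stand-in classes, not the repo's):

import torch
from torch import nn
from typing import Optional

class LinearSketch(nn.Module):             # stand-in for the diff's Linear
    dtype = torch.bfloat16                 # class attributes act as shared,
    scale_fmt: Optional[str] = None        # process-wide configuration

class ColumnParallelSketch(LinearSketch):  # subclasses inherit them
    pass

LinearSketch.scale_fmt = "ue8m0"           # what Transformer.__init__ does
assert ColumnParallelSketch().scale_fmt == "ue8m0"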
@@ -208,7 +202,7 @@
         Returns:
             torch.Tensor: Transformed tensor after linear computation.
         """
-        return linear(x, self.weight, self.bias)
+        return linear(x, self.weight, self.bias, self.scale_fmt)
 
 
 class ColumnParallelLinear(Linear):
@@ -764,6 +758,7 @@ class Transformer(nn.Module):
         world_size = dist.get_world_size() if dist.is_initialized() else 1
         rank = dist.get_rank() if dist.is_initialized() else 0
         Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
+        Linear.scale_fmt = args.scale_fmt
         super().__init__()
         self.max_seq_len = args.max_seq_len
         self.embed = ParallelEmbedding(args.vocab_size, args.dim)
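Put together, the commit replaces global mutable state with a one-way flow from config to kernel call; sketched below with names taken from the hunks above:

# config_671B_v3.1.json      "dtype": "fp8", "scale_fmt": "ue8m0"
#   -> generate.py main()    args = ModelArgs(**config_dict)
#   -> Transformer.__init__  Linear.dtype     = torch.float8_e4m3fn
#                            Linear.scale_fmt = args.scale_fmt
#   -> Linear.forward        linear(x, self.weight, self.bias, self.scale_fmt)
#   -> linear()              x, scale = act_quant(x, block_size, scale_fmt)
#                            y = fp8_gemm(x, scale, weight, weight.scale)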