From 348e741a11f4d0885744859edde7f81ad006f7ef Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 26 Aug 2025 18:08:50 +0800
Subject: [PATCH] keep improving

Signed-off-by: youkaichao
---
 inference/configs/config_671B_v3.1.json | 23 +++++++++++++++++++++++
 inference/generate.py                   |  4 ----
 inference/model.py                      | 15 +++++----------
 3 files changed, 28 insertions(+), 14 deletions(-)
 create mode 100644 inference/configs/config_671B_v3.1.json

diff --git a/inference/configs/config_671B_v3.1.json b/inference/configs/config_671B_v3.1.json
new file mode 100644
index 0000000..091d4cc
--- /dev/null
+++ b/inference/configs/config_671B_v3.1.json
@@ -0,0 +1,23 @@
+{
+    "vocab_size": 129280,
+    "dim": 7168,
+    "inter_dim": 18432,
+    "moe_inter_dim": 2048,
+    "n_layers": 61,
+    "n_dense_layers": 3,
+    "n_heads": 128,
+    "n_routed_experts": 256,
+    "n_shared_experts": 1,
+    "n_activated_experts": 8,
+    "n_expert_groups": 8,
+    "n_limited_groups": 4,
+    "route_scale": 2.5,
+    "score_func": "sigmoid",
+    "q_lora_rank": 1536,
+    "kv_lora_rank": 512,
+    "qk_nope_head_dim": 128,
+    "qk_rope_head_dim": 64,
+    "v_head_dim": 128,
+    "dtype": "fp8",
+    "scale_fmt": "ue8m0"
+}
\ No newline at end of file
diff --git a/inference/generate.py b/inference/generate.py
index 7a811e6..deb594e 100644
--- a/inference/generate.py
+++ b/inference/generate.py
@@ -112,10 +112,6 @@ def main(
     with open(config) as f:
         config_dict = json.load(f)
     args = ModelArgs(**config_dict)
-    quantization_config = config_dict.get("quantization_config", None)
-    if quantization_config is not None:
-        args.scale_fmt = quantization_config.get("scale_fmt", None)
-    set_global_args(args)
     print(args)
     with torch.device("cuda"):
         model = Transformer(args)
diff --git a/inference/model.py b/inference/model.py
index 2e0af14..28311cd 100644
--- a/inference/model.py
+++ b/inference/model.py
@@ -85,12 +85,6 @@ class ModelArgs:
     beta_slow: int = 1
     mscale: float = 1.
 
-global_args: Optional[ModelArgs] = None
-
-def set_global_args(args: ModelArgs):
-    global global_args
-    global_args = args
-
 
 class ParallelEmbedding(nn.Module):
     """
@@ -134,7 +128,7 @@ class ParallelEmbedding(nn.Module):
         return y
 
 
-def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, scale_fmt: Optional[str] = None) -> torch.Tensor:
     """
     Applies a linear transformation to the incoming data: y = xA^T + b.
     This function supports specialized implementations based on quantization
@@ -162,8 +156,7 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
         weight = weight_dequant(weight, weight.scale)
         return F.linear(x, weight, bias)
     else:
-        assert global_args is not None, "global_args is required for fp8_gemm"
-        x, scale = act_quant(x, block_size, global_args.scale_fmt)
+        x, scale = act_quant(x, block_size, scale_fmt)
         y = fp8_gemm(x, scale, weight, weight.scale)
         if bias is not None:
             y += bias
@@ -181,6 +174,7 @@ class Linear(nn.Module):
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
     dtype = torch.bfloat16
+    scale_fmt: Optional[str] = None
 
     def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
         super().__init__()
@@ -208,7 +202,7 @@ class Linear(nn.Module):
         Returns:
             torch.Tensor: Transformed tensor after linear computation.
         """
-        return linear(x, self.weight, self.bias)
+        return linear(x, self.weight, self.bias, self.scale_fmt)
 
 
 class ColumnParallelLinear(Linear):
@@ -764,6 +758,7 @@ class Transformer(nn.Module):
         world_size = dist.get_world_size() if dist.is_initialized() else 1
         rank = dist.get_rank() if dist.is_initialized() else 0
         Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
+        Linear.scale_fmt = args.scale_fmt
         super().__init__()
         self.max_seq_len = args.max_seq_len
         self.embed = ParallelEmbedding(args.vocab_size, args.dim)