mirror of
https://git.datalinker.icu/deepseek-ai/DeepSeek-V3.git
synced 2026-05-23 03:09:08 +08:00
Merge 6847b3d86692438b63a831f7649705c18d7d333b into 9b4e9788e4a3a731f7567338ed15d3ec549ce03b
This commit is contained in:
commit
ab76fa01ca
@ -246,6 +246,11 @@ cd inference
|
|||||||
python fp8_cast_bf16.py --input-fp8-hf-path /path/to/fp8_weights --output-bf16-hf-path /path/to/bf16_weights
|
python fp8_cast_bf16.py --input-fp8-hf-path /path/to/fp8_weights --output-bf16-hf-path /path/to/bf16_weights
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> The conversion script now supports optional arguments for broader portability:
|
||||||
|
> - `--device`: `auto` (default), `cuda`, or `cpu`. With `auto`, the script uses CUDA if it is available and otherwise falls back to CPU. `cuda` requests the GPU path but likewise falls back to CPU when CUDA is unavailable. Use `cpu` to force the CPU-only path, e.g. on systems without GPUs.
|
||||||
|
> - `--block-size`: Block size used during quantization/dequantization (default `128`). This should match the model’s tiling settings at export time.
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> Hugging Face's Transformers is not directly supported yet.
|
> Hugging Face's Transformers is not directly supported yet.
|
||||||
|
|
||||||
|
|||||||
@ -2,14 +2,47 @@ import os
|
|||||||
import json
|
import json
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from glob import glob
|
from glob import glob
|
||||||
|
from typing import Dict
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from safetensors.torch import load_file, save_file
|
from safetensors.torch import load_file, save_file
|
||||||
|
|
||||||
from kernel import weight_dequant
|
# Use the optimized Triton GPU kernel when available; add a CPU fallback below
|
||||||
|
from kernel import weight_dequant as weight_dequant_gpu
|
||||||
|
|
||||||
def main(fp8_path, bf16_path):
|
def _infer_device(requested_device: str) -> str:
|
||||||
|
"""
|
||||||
|
Decide which device to use for tensor I/O and dequantization.
|
||||||
|
|
||||||
|
- "auto": prefer CUDA if available, otherwise CPU
|
||||||
|
- "cuda": use CUDA when available, else fall back to CPU
|
||||||
|
- "cpu": force CPU
|
||||||
|
"""
|
||||||
|
if requested_device == "cuda":
|
||||||
|
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
if requested_device == "cpu":
|
||||||
|
return "cpu"
|
||||||
|
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def _weight_dequant_cpu(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
CPU fallback for FP8 weight dequantization using pure PyTorch.
|
||||||
|
|
||||||
|
Expands per-block scales to a full resolution map and rescales FP8 weights.
|
||||||
|
Prioritizes correctness and portability; may be memory intensive for huge tensors.
|
||||||
|
"""
|
||||||
|
assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions"
|
||||||
|
M, N = x.shape
|
||||||
|
m_blocks, n_blocks = s.shape
|
||||||
|
assert M % block_size == 0 and N % block_size == 0, "Weight dims must be multiples of block_size"
|
||||||
|
assert m_blocks == M // block_size and n_blocks == N // block_size, "Scale shape must match weight tiling"
|
||||||
|
s_full = s.repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1)
|
||||||
|
y = x.to(torch.float32) * s_full
|
||||||
|
return y.to(torch.get_default_dtype())
|
||||||
|
|
||||||
|
def convert_fp8_to_bf16(fp8_path: str, bf16_path: str, device: str = "auto", block_size: int = 128):
|
||||||
"""
|
"""
|
||||||
Converts FP8 weights to BF16 and saves the converted weights.
|
Converts FP8 weights to BF16 and saves the converted weights.
|
||||||
|
|
||||||
@ -20,6 +53,8 @@ def main(fp8_path, bf16_path):
|
|||||||
Args:
|
Args:
|
||||||
fp8_path (str): The path to the directory containing the FP8 weights and model index file.
|
fp8_path (str): The path to the directory containing the FP8 weights and model index file.
|
||||||
bf16_path (str): The path to the directory where the converted BF16 weights will be saved.
|
bf16_path (str): The path to the directory where the converted BF16 weights will be saved.
|
||||||
|
device (str): One of {"auto", "cuda", "cpu"}. Controls load/compute device.
|
||||||
|
block_size (int): Block size used by quantization/dequantization. Typically 128.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
KeyError: If a required scale_inv tensor is missing for a weight.
|
KeyError: If a required scale_inv tensor is missing for a weight.
|
||||||
@ -31,13 +66,14 @@ def main(fp8_path, bf16_path):
|
|||||||
"""
|
"""
|
||||||
torch.set_default_dtype(torch.bfloat16)
|
torch.set_default_dtype(torch.bfloat16)
|
||||||
os.makedirs(bf16_path, exist_ok=True)
|
os.makedirs(bf16_path, exist_ok=True)
|
||||||
|
effective_device = _infer_device(device)
|
||||||
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
|
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
|
||||||
with open(model_index_file, "r") as f:
|
with open(model_index_file, "r") as f:
|
||||||
model_index = json.load(f)
|
model_index = json.load(f)
|
||||||
weight_map = model_index["weight_map"]
|
weight_map = model_index["weight_map"]
|
||||||
|
|
||||||
# Cache for loaded safetensor files
|
# Cache for loaded safetensor files
|
||||||
loaded_files = {}
|
loaded_files: Dict[str, Dict[str, torch.Tensor]] = {}
|
||||||
fp8_weight_names = []
|
fp8_weight_names = []
|
||||||
|
|
||||||
# Helper function to get tensor from the correct file
|
# Helper function to get tensor from the correct file
|
||||||
@ -57,14 +93,14 @@ def main(fp8_path, bf16_path):
|
|||||||
file_name = weight_map[tensor_name]
|
file_name = weight_map[tensor_name]
|
||||||
if file_name not in loaded_files:
|
if file_name not in loaded_files:
|
||||||
file_path = os.path.join(fp8_path, file_name)
|
file_path = os.path.join(fp8_path, file_name)
|
||||||
loaded_files[file_name] = load_file(file_path, device="cuda")
|
loaded_files[file_name] = load_file(file_path, device=effective_device)
|
||||||
return loaded_files[file_name][tensor_name]
|
return loaded_files[file_name][tensor_name]
|
||||||
|
|
||||||
safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
|
safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
|
||||||
safetensor_files.sort()
|
safetensor_files.sort()
|
||||||
for safetensor_file in tqdm(safetensor_files):
|
for safetensor_file in tqdm(safetensor_files):
|
||||||
file_name = os.path.basename(safetensor_file)
|
file_name = os.path.basename(safetensor_file)
|
||||||
current_state_dict = load_file(safetensor_file, device="cuda")
|
current_state_dict = load_file(safetensor_file, device=effective_device)
|
||||||
loaded_files[file_name] = current_state_dict
|
loaded_files[file_name] = current_state_dict
|
||||||
|
|
||||||
new_state_dict = {}
|
new_state_dict = {}
|
||||||
@ -77,7 +113,10 @@ def main(fp8_path, bf16_path):
|
|||||||
# Get scale_inv from the correct file
|
# Get scale_inv from the correct file
|
||||||
scale_inv = get_tensor(scale_inv_name)
|
scale_inv = get_tensor(scale_inv_name)
|
||||||
fp8_weight_names.append(weight_name)
|
fp8_weight_names.append(weight_name)
|
||||||
new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
|
if effective_device == "cuda":
|
||||||
|
new_state_dict[weight_name] = weight_dequant_gpu(weight, scale_inv, block_size=block_size)
|
||||||
|
else:
|
||||||
|
new_state_dict[weight_name] = _weight_dequant_cpu(weight, scale_inv, block_size=block_size)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
|
print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
|
||||||
new_state_dict[weight_name] = weight
|
new_state_dict[weight_name] = weight
|
||||||
@ -91,7 +130,8 @@ def main(fp8_path, bf16_path):
|
|||||||
if len(loaded_files) > 2:
|
if len(loaded_files) > 2:
|
||||||
oldest_file = next(iter(loaded_files))
|
oldest_file = next(iter(loaded_files))
|
||||||
del loaded_files[oldest_file]
|
del loaded_files[oldest_file]
|
||||||
torch.cuda.empty_cache()
|
if effective_device == "cuda":
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
# Update model index
|
# Update model index
|
||||||
new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
|
new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
|
||||||
@ -107,6 +147,10 @@ if __name__ == "__main__":
|
|||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
parser.add_argument("--input-fp8-hf-path", type=str, required=True)
|
parser.add_argument("--input-fp8-hf-path", type=str, required=True)
|
||||||
parser.add_argument("--output-bf16-hf-path", type=str, required=True)
|
parser.add_argument("--output-bf16-hf-path", type=str, required=True)
|
||||||
|
parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"],
|
||||||
|
help="Select compute device: 'auto' prefers CUDA if available; otherwise CPU.")
|
||||||
|
parser.add_argument("--block-size", type=int, default=128,
|
||||||
|
help="Block size used during quantization/dequantization. Typically 128.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args.input_fp8_hf_path, args.output_bf16_hf_path)
|
convert_fp8_to_bf16(args.input_fp8_hf_path, args.output_bf16_hf_path, device=args.device, block_size=args.block_size)
|
||||||
|
|
||||||
|
|||||||
74
inference/tests/test_fp8_cast_bf16.py
Normal file
74
inference/tests/test_fp8_cast_bf16.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import save_file, load_file
|
||||||
|
|
||||||
|
# Import the conversion API we expose for programmatic use
|
||||||
|
from inference.fp8_cast_bf16 import convert_fp8_to_bf16
|
||||||
|
|
||||||
|
|
||||||
|
def _make_block_scale(shape_blocks: Tuple[int, int], value: float, device: str) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Create a per-block scale tensor of shape (M_blocks, N_blocks) filled with a constant.
|
||||||
|
"""
|
||||||
|
return torch.full(shape_blocks, value, dtype=torch.float32, device=device).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_fp8_to_bf16_cpu_roundtrip_small_matrix():
    """
    Validate CPU fallback by constructing a tiny FP8 weight with known block scales,
    converting to BF16, and checking the recovered values.

    The fixture is a single 4x4 weight tiled into 2x2 blocks with a uniform
    scale of 0.5, stored in one safetensors shard next to a minimal index file.
    """
    if not hasattr(torch, "float8_e4m3fn"):
        # Skip if PyTorch build lacks float8 support
        return

    device = "cpu"
    block_size = 2
    M, N = 4, 4
    # Choose a uniform block scale that is easy to reason about
    scale_value = 0.5  # multiplicative factor used during dequant
    # Construct the target dequantized weights (what we want to recover)
    y_true = torch.tensor(
        [[1.0, 2.0, 3.0, 4.0],
         [5.0, 6.0, 7.0, 8.0],
         [2.0, 4.0, 6.0, 8.0],
         [1.5, 2.5, 3.5, 4.5]],
        dtype=torch.float32,
        device=device,
    )
    # Create the per-block scale tensor: (M // block_size, N // block_size)
    s = _make_block_scale((M // block_size, N // block_size), scale_value, device)
    # Build FP8 weights such that dequant (x * scale) recovers y_true
    x_fp32 = (y_true / scale_value).contiguous()
    x_fp8 = x_fp32.to(torch.float8_e4m3fn)

    # convert_fp8_to_bf16 switches the process-wide default dtype to bfloat16;
    # remember the current one so this test does not leak that change.
    previous_dtype = torch.get_default_dtype()
    try:
        with tempfile.TemporaryDirectory() as tmp:
            fp8_dir = os.path.join(tmp, "fp8")
            bf16_dir = os.path.join(tmp, "bf16")
            os.makedirs(fp8_dir, exist_ok=True)
            os.makedirs(bf16_dir, exist_ok=True)

            # Create minimal safetensors shard and index
            shard = {"layer.weight": x_fp8, "layer.weight_scale_inv": s}
            shard_name = "model-00001-of-00001.safetensors"
            save_file(shard, os.path.join(fp8_dir, shard_name))

            index = {"metadata": {}, "weight_map": {"layer.weight": shard_name, "layer.weight_scale_inv": shard_name}}
            with open(os.path.join(fp8_dir, "model.safetensors.index.json"), "w") as f:
                json.dump(index, f)

            # Run conversion using CPU path and a small block size
            convert_fp8_to_bf16(fp8_dir, bf16_dir, device="cpu", block_size=block_size)

            # Load converted weights and verify they match the expected y_true (within tolerance)
            out_shard = load_file(os.path.join(bf16_dir, shard_name), device=device)
            y = out_shard["layer.weight"].to(torch.float32)
            assert torch.allclose(y, y_true, atol=1e-2, rtol=1e-2)
    finally:
        torch.set_default_dtype(previous_dtype)
|
||||||
|
|
||||||
|
|
||||||
Loading…
x
Reference in New Issue
Block a user