Merge 6847b3d86692438b63a831f7649705c18d7d333b into 9b4e9788e4a3a731f7567338ed15d3ec549ce03b

This commit is contained in:
Richard Ogundele 2025-09-26 22:19:39 +01:00 committed by GitHub
commit ab76fa01ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 131 additions and 8 deletions

View File

@ -246,6 +246,11 @@ cd inference
python fp8_cast_bf16.py --input-fp8-hf-path /path/to/fp8_weights --output-bf16-hf-path /path/to/bf16_weights python fp8_cast_bf16.py --input-fp8-hf-path /path/to/fp8_weights --output-bf16-hf-path /path/to/bf16_weights
``` ```
> [!TIP]
> The conversion script now supports optional arguments for broader portability:
> - `--device`: `auto` (default), `cuda`, or `cpu`. When set to `auto`, the script prefers CUDA if available, otherwise falls back to CPU. Use `cpu` to force a CPU-only path on systems without GPUs.
> - `--block-size`: Block size used during quantization/dequantization (default `128`). This should match the model's tiling settings at export time.
> [!NOTE] > [!NOTE]
> Hugging Face's Transformers has not been directly supported yet. > Hugging Face's Transformers has not been directly supported yet.

View File

@ -2,14 +2,47 @@ import os
import json import json
from argparse import ArgumentParser from argparse import ArgumentParser
from glob import glob from glob import glob
from typing import Dict
from tqdm import tqdm from tqdm import tqdm
import torch import torch
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
from kernel import weight_dequant # Use the optimized Triton GPU kernel when available; add a CPU fallback below
from kernel import weight_dequant as weight_dequant_gpu
def main(fp8_path, bf16_path): def _infer_device(requested_device: str) -> str:
"""
Decide which device to use for tensor I/O and dequantization.
- "auto": prefer CUDA if available, otherwise CPU
- "cuda": use CUDA when available, else fall back to CPU
- "cpu": force CPU
"""
if requested_device == "cuda":
return "cuda" if torch.cuda.is_available() else "cpu"
if requested_device == "cpu":
return "cpu"
return "cuda" if torch.cuda.is_available() else "cpu"
def _weight_dequant_cpu(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
"""
CPU fallback for FP8 weight dequantization using pure PyTorch.
Expands per-block scales to a full resolution map and rescales FP8 weights.
Prioritizes correctness and portability; may be memory intensive for huge tensors.
"""
assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions"
M, N = x.shape
m_blocks, n_blocks = s.shape
assert M % block_size == 0 and N % block_size == 0, "Weight dims must be multiples of block_size"
assert m_blocks == M // block_size and n_blocks == N // block_size, "Scale shape must match weight tiling"
s_full = s.repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1)
y = x.to(torch.float32) * s_full
return y.to(torch.get_default_dtype())
def convert_fp8_to_bf16(fp8_path: str, bf16_path: str, device: str = "auto", block_size: int = 128):
""" """
Converts FP8 weights to BF16 and saves the converted weights. Converts FP8 weights to BF16 and saves the converted weights.
@ -20,6 +53,8 @@ def main(fp8_path, bf16_path):
Args: Args:
fp8_path (str): The path to the directory containing the FP8 weights and model index file. fp8_path (str): The path to the directory containing the FP8 weights and model index file.
bf16_path (str): The path to the directory where the converted BF16 weights will be saved. bf16_path (str): The path to the directory where the converted BF16 weights will be saved.
device (str): One of {"auto", "cuda", "cpu"}. Controls load/compute device.
block_size (int): Block size used by quantization/dequantization. Typically 128.
Raises: Raises:
KeyError: If a required scale_inv tensor is missing for a weight. KeyError: If a required scale_inv tensor is missing for a weight.
@ -31,13 +66,14 @@ def main(fp8_path, bf16_path):
""" """
torch.set_default_dtype(torch.bfloat16) torch.set_default_dtype(torch.bfloat16)
os.makedirs(bf16_path, exist_ok=True) os.makedirs(bf16_path, exist_ok=True)
effective_device = _infer_device(device)
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json") model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
with open(model_index_file, "r") as f: with open(model_index_file, "r") as f:
model_index = json.load(f) model_index = json.load(f)
weight_map = model_index["weight_map"] weight_map = model_index["weight_map"]
# Cache for loaded safetensor files # Cache for loaded safetensor files
loaded_files = {} loaded_files: Dict[str, Dict[str, torch.Tensor]] = {}
fp8_weight_names = [] fp8_weight_names = []
# Helper function to get tensor from the correct file # Helper function to get tensor from the correct file
@ -57,14 +93,14 @@ def main(fp8_path, bf16_path):
file_name = weight_map[tensor_name] file_name = weight_map[tensor_name]
if file_name not in loaded_files: if file_name not in loaded_files:
file_path = os.path.join(fp8_path, file_name) file_path = os.path.join(fp8_path, file_name)
loaded_files[file_name] = load_file(file_path, device="cuda") loaded_files[file_name] = load_file(file_path, device=effective_device)
return loaded_files[file_name][tensor_name] return loaded_files[file_name][tensor_name]
safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors"))) safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
safetensor_files.sort() safetensor_files.sort()
for safetensor_file in tqdm(safetensor_files): for safetensor_file in tqdm(safetensor_files):
file_name = os.path.basename(safetensor_file) file_name = os.path.basename(safetensor_file)
current_state_dict = load_file(safetensor_file, device="cuda") current_state_dict = load_file(safetensor_file, device=effective_device)
loaded_files[file_name] = current_state_dict loaded_files[file_name] = current_state_dict
new_state_dict = {} new_state_dict = {}
@ -77,7 +113,10 @@ def main(fp8_path, bf16_path):
# Get scale_inv from the correct file # Get scale_inv from the correct file
scale_inv = get_tensor(scale_inv_name) scale_inv = get_tensor(scale_inv_name)
fp8_weight_names.append(weight_name) fp8_weight_names.append(weight_name)
new_state_dict[weight_name] = weight_dequant(weight, scale_inv) if effective_device == "cuda":
new_state_dict[weight_name] = weight_dequant_gpu(weight, scale_inv, block_size=block_size)
else:
new_state_dict[weight_name] = _weight_dequant_cpu(weight, scale_inv, block_size=block_size)
except KeyError: except KeyError:
print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion") print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
new_state_dict[weight_name] = weight new_state_dict[weight_name] = weight
@ -91,7 +130,8 @@ def main(fp8_path, bf16_path):
if len(loaded_files) > 2: if len(loaded_files) > 2:
oldest_file = next(iter(loaded_files)) oldest_file = next(iter(loaded_files))
del loaded_files[oldest_file] del loaded_files[oldest_file]
torch.cuda.empty_cache() if effective_device == "cuda":
torch.cuda.empty_cache()
# Update model index # Update model index
new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json") new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
@ -107,6 +147,10 @@ if __name__ == "__main__":
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("--input-fp8-hf-path", type=str, required=True) parser.add_argument("--input-fp8-hf-path", type=str, required=True)
parser.add_argument("--output-bf16-hf-path", type=str, required=True) parser.add_argument("--output-bf16-hf-path", type=str, required=True)
parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"],
help="Select compute device: 'auto' prefers CUDA if available; otherwise CPU.")
parser.add_argument("--block-size", type=int, default=128,
help="Block size used during quantization/dequantization. Typically 128.")
args = parser.parse_args() args = parser.parse_args()
main(args.input_fp8_hf_path, args.output_bf16_hf_path) convert_fp8_to_bf16(args.input_fp8_hf_path, args.output_bf16_hf_path, device=args.device, block_size=args.block_size)

View File

@ -0,0 +1,74 @@
import os
import json
import tempfile
from typing import Tuple
import torch
from safetensors.torch import save_file, load_file
# Import the conversion API we expose for programmatic use
from inference.fp8_cast_bf16 import convert_fp8_to_bf16
def _make_block_scale(shape_blocks: Tuple[int, int], value: float, device: str) -> torch.Tensor:
"""
Create a per-block scale tensor of shape (M_blocks, N_blocks) filled with a constant.
"""
return torch.full(shape_blocks, value, dtype=torch.float32, device=device).contiguous()
def test_convert_fp8_to_bf16_cpu_roundtrip_small_matrix():
    """
    Validate the CPU conversion path end to end on a tiny matrix.

    Builds a 4x4 FP8 weight with known 2x2 block scales, writes it as a one-shard
    safetensors checkpoint with an index file, runs `convert_fp8_to_bf16` on CPU,
    and checks that the converted weights match the expected values.
    """
    if not hasattr(torch, "float8_e4m3fn"):
        # This PyTorch build lacks float8 support, so there is nothing to exercise.
        return

    device = "cpu"
    block_size = 2
    M, N = 4, 4

    # Uniform block scale that is easy to reason about (multiplicative factor
    # applied during dequantization).
    scale_value = 0.5

    # Target dequantized weights (what the conversion should recover).
    y_true = torch.tensor(
        [[1.0, 2.0, 3.0, 4.0],
         [5.0, 6.0, 7.0, 8.0],
         [2.0, 4.0, 6.0, 8.0],
         [1.5, 2.5, 3.5, 4.5]],
        dtype=torch.float32,
        device=device,
    )

    # Per-block scale tensor of shape (M // block_size, N // block_size).
    s = _make_block_scale((M // block_size, N // block_size), scale_value, device)

    # Expand the scales to full resolution and build FP8 weights such that
    # dequantization (x * s_full) recovers y_true.
    s_full = s.repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1)
    x_fp32 = (y_true / s_full).contiguous()
    x_fp8 = x_fp32.to(torch.float8_e4m3fn)

    # The converter changes the process-wide default dtype; remember the current
    # one so this test does not leak that state into other tests.
    previous_default_dtype = torch.get_default_dtype()
    try:
        with tempfile.TemporaryDirectory() as tmp:
            fp8_dir = os.path.join(tmp, "fp8")
            bf16_dir = os.path.join(tmp, "bf16")
            os.makedirs(fp8_dir, exist_ok=True)
            os.makedirs(bf16_dir, exist_ok=True)

            # Create a minimal safetensors shard and its index.
            shard = {"layer.weight": x_fp8, "layer.weight_scale_inv": s}
            shard_name = "model-00001-of-00001.safetensors"
            save_file(shard, os.path.join(fp8_dir, shard_name))

            index = {
                "metadata": {},
                "weight_map": {
                    "layer.weight": shard_name,
                    "layer.weight_scale_inv": shard_name,
                },
            }
            with open(os.path.join(fp8_dir, "model.safetensors.index.json"), "w") as f:
                json.dump(index, f)

            # Run the conversion on the CPU path with the small block size.
            convert_fp8_to_bf16(fp8_dir, bf16_dir, device="cpu", block_size=block_size)

            # Load the converted weights and compare against y_true (within tolerance).
            out_shard = load_file(os.path.join(bf16_dir, shard_name), device=device)
            y = out_shard["layer.weight"].to(torch.float32)
            assert torch.allclose(y, y_true, atol=1e-2, rtol=1e-2)
    finally:
        torch.set_default_dtype(previous_default_dtype)