diff --git a/README.md b/README.md
index e94a77f..3c02021 100644
--- a/README.md
+++ b/README.md
@@ -246,6 +246,11 @@
 cd inference
 python fp8_cast_bf16.py --input-fp8-hf-path /path/to/fp8_weights --output-bf16-hf-path /path/to/bf16_weights
 ```
+> [!TIP]
+> The conversion script now supports optional arguments for broader portability:
+> - `--device`: `auto` (default), `cuda`, or `cpu`. With `auto`, the script prefers CUDA if available and otherwise falls back to CPU; use `cpu` to force a CPU-only path on systems without GPUs (an explicit `cuda` request also falls back to CPU when no GPU is visible).
+> - `--block-size`: Block size used during quantization/dequantization (default `128`). This should match the model’s tiling settings at export time.
+
 
 > [!NOTE]
 > Hugging Face's Transformers has not been directly supported yet.
diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py
index 4037342..f8ac704 100644
--- a/inference/fp8_cast_bf16.py
+++ b/inference/fp8_cast_bf16.py
@@ -2,14 +2,47 @@ import os
 import json
 from argparse import ArgumentParser
 from glob import glob
+from typing import Dict
 
 from tqdm import tqdm
 import torch
 from safetensors.torch import load_file, save_file
 
-from kernel import weight_dequant
+# Use the optimized Triton GPU kernel for the CUDA path; a pure-PyTorch CPU fallback is defined below (importing `kernel` still requires triton)
+from kernel import weight_dequant as weight_dequant_gpu
 
-def main(fp8_path, bf16_path):
+def _infer_device(requested_device: str) -> str:
+    """
+    Decide which device to use for tensor I/O and dequantization.
+
+    - "auto": prefer CUDA if available, otherwise CPU
+    - "cuda": use CUDA when available, else fall back to CPU
+    - "cpu": force CPU
+    """
+    if requested_device == "cuda":
+        return "cuda" if torch.cuda.is_available() else "cpu"
+    if requested_device == "cpu":
+        return "cpu"
+    return "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def _weight_dequant_cpu(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
+    """
+    CPU fallback for FP8 weight dequantization using pure PyTorch.
+
+    Expands the per-block scales to a full-resolution map and rescales the FP8 weights.
+    Prioritizes correctness and portability; may be memory-intensive for very large tensors.
+    """
+    assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions"
+    M, N = x.shape
+    m_blocks, n_blocks = s.shape
+    assert M % block_size == 0 and N % block_size == 0, "Weight dims must be multiples of block_size"
+    assert m_blocks == M // block_size and n_blocks == N // block_size, "Scale shape must match weight tiling"
+    s_full = s.repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1)
+    y = x.to(torch.float32) * s_full
+    return y.to(torch.get_default_dtype())
+
+def convert_fp8_to_bf16(fp8_path: str, bf16_path: str, device: str = "auto", block_size: int = 128):
     """
     Converts FP8 weights to BF16 and saves the converted weights.
 
@@ -20,6 +53,8 @@ def main(fp8_path, bf16_path):
     Args:
         fp8_path (str): The path to the directory containing the FP8 weights and model index file.
         bf16_path (str): The path to the directory where the converted BF16 weights will be saved.
+        device (str): One of {"auto", "cuda", "cpu"}. Controls the load/compute device.
+        block_size (int): Block size used by quantization/dequantization. Typically 128.
 
     Raises:
         KeyError: If a required scale_inv tensor is missing for a weight.
@@ -31,13 +66,14 @@ def main(fp8_path, bf16_path):
     """
     torch.set_default_dtype(torch.bfloat16)
     os.makedirs(bf16_path, exist_ok=True)
+    effective_device = _infer_device(device)
     model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
     with open(model_index_file, "r") as f:
         model_index = json.load(f)
     weight_map = model_index["weight_map"]
 
     # Cache for loaded safetensor files
-    loaded_files = {}
+    loaded_files: Dict[str, Dict[str, torch.Tensor]] = {}
     fp8_weight_names = []
 
     # Helper function to get tensor from the correct file
@@ -57,14 +93,14 @@ def main(fp8_path, bf16_path):
         file_name = weight_map[tensor_name]
         if file_name not in loaded_files:
             file_path = os.path.join(fp8_path, file_name)
-            loaded_files[file_name] = load_file(file_path, device="cuda")
+            loaded_files[file_name] = load_file(file_path, device=effective_device)
         return loaded_files[file_name][tensor_name]
 
     safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
     safetensor_files.sort()
     for safetensor_file in tqdm(safetensor_files):
         file_name = os.path.basename(safetensor_file)
-        current_state_dict = load_file(safetensor_file, device="cuda")
+        current_state_dict = load_file(safetensor_file, device=effective_device)
         loaded_files[file_name] = current_state_dict
 
         new_state_dict = {}
@@ -77,7 +113,10 @@ def main(fp8_path, bf16_path):
                     # Get scale_inv from the correct file
                     scale_inv = get_tensor(scale_inv_name)
                     fp8_weight_names.append(weight_name)
-                    new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
+                    if effective_device == "cuda":
+                        new_state_dict[weight_name] = weight_dequant_gpu(weight, scale_inv, block_size=block_size)
+                    else:
+                        new_state_dict[weight_name] = _weight_dequant_cpu(weight, scale_inv, block_size=block_size)
                 except KeyError:
                     print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
                     new_state_dict[weight_name] = weight
@@ -91,7 +130,8 @@ def main(fp8_path, bf16_path):
         if len(loaded_files) > 2:
             oldest_file = next(iter(loaded_files))
             del loaded_files[oldest_file]
-            torch.cuda.empty_cache()
+            if effective_device == "cuda":
+                torch.cuda.empty_cache()
 
     # Update model index
     new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
@@ -107,6 +147,10 @@ if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("--input-fp8-hf-path", type=str, required=True)
     parser.add_argument("--output-bf16-hf-path", type=str, required=True)
+    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"],
+                        help="Select the compute device: 'auto' prefers CUDA if available, otherwise CPU.")
+    parser.add_argument("--block-size", type=int, default=128,
+                        help="Block size used during quantization/dequantization. Typically 128.")
     args = parser.parse_args()
-    main(args.input_fp8_hf_path, args.output_bf16_hf_path)
+    convert_fp8_to_bf16(args.input_fp8_hf_path, args.output_bf16_hf_path, device=args.device, block_size=args.block_size)
 
diff --git a/inference/tests/test_fp8_cast_bf16.py b/inference/tests/test_fp8_cast_bf16.py
new file mode 100644
index 0000000..bd3f32b
--- /dev/null
+++ b/inference/tests/test_fp8_cast_bf16.py
@@ -0,0 +1,74 @@
+import os
+import json
+import tempfile
+from typing import Tuple
+
+import torch
+from safetensors.torch import save_file, load_file
+
+# Import the conversion API we expose for programmatic use
+from inference.fp8_cast_bf16 import convert_fp8_to_bf16
+
+
+def _make_block_scale(shape_blocks: Tuple[int, int], value: float, device: str) -> torch.Tensor:
+    """
+    Create a per-block scale tensor of shape (M_blocks, N_blocks) filled with a constant.
+    """
+    return torch.full(shape_blocks, value, dtype=torch.float32, device=device).contiguous()
+
+
+def test_convert_fp8_to_bf16_cpu_roundtrip_small_matrix():
+    """
+    Validate the CPU fallback by constructing a tiny FP8 weight with known block scales,
+    converting it to BF16, and checking the recovered values.
+    """
+    if not hasattr(torch, "float8_e4m3fn"):
+        # Skip if this PyTorch build lacks float8 support
+        return
+
+    device = "cpu"
+    block_size = 2
+    M, N = 4, 4
+    # Choose a uniform block scale that is easy to reason about
+    scale_value = 0.5  # multiplicative factor used during dequantization
+    # Construct the target dequantized weights (what we want to recover)
+    y_true = torch.tensor(
+        [[1.0, 2.0, 3.0, 4.0],
+         [5.0, 6.0, 7.0, 8.0],
+         [2.0, 4.0, 6.0, 8.0],
+         [1.5, 2.5, 3.5, 4.5]],
+        dtype=torch.float32,
+        device=device,
+    )
+    # Create the per-block scale tensor: (M // block_size, N // block_size)
+    s = _make_block_scale((M // block_size, N // block_size), scale_value, device)
+    # Expand s to full resolution for constructing the FP8 quantized weights
+    s_full = s.repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1)
+    # Build FP8 weights such that dequantization (x * s_full) recovers y_true
+    x_fp32 = (y_true / scale_value).contiguous()
+    x_fp8 = x_fp32.to(torch.float8_e4m3fn)
+
+    with tempfile.TemporaryDirectory() as tmp:
+        fp8_dir = os.path.join(tmp, "fp8")
+        bf16_dir = os.path.join(tmp, "bf16")
+        os.makedirs(fp8_dir, exist_ok=True)
+        os.makedirs(bf16_dir, exist_ok=True)
+
+        # Create a minimal safetensors shard and index
+        shard = {"layer.weight": x_fp8, "layer.weight_scale_inv": s}
+        shard_name = "model-00001-of-00001.safetensors"
+        save_file(shard, os.path.join(fp8_dir, shard_name))
+
+        index = {"metadata": {},
+                 "weight_map": {"layer.weight": shard_name, "layer.weight_scale_inv": shard_name}}
+        with open(os.path.join(fp8_dir, "model.safetensors.index.json"), "w") as f:
+            json.dump(index, f)
+
+        # Run the conversion using the CPU path and a small block size
+        convert_fp8_to_bf16(fp8_dir, bf16_dir, device="cpu", block_size=block_size)
+
+        # Load the converted weights and verify they match the expected y_true (within tolerance)
+        out_shard = load_file(os.path.join(bf16_dir, shard_name), device=device)
+        y = out_shard["layer.weight"].to(torch.float32)
+        assert torch.allclose(y, y_true, atol=1e-2, rtol=1e-2)
+