# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm._custom_ops import scaled_fp4_quant
from vllm.scalar_type import scalar_types

FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

kE2M1ToFloat = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)
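
# Decoding note: each 4-bit E2M1 code is one sign bit (bit 3) plus a 3-bit
# index into the magnitude table above, so e.g. code 0b1110 decodes to
# -kE2M1ToFloat[6] = -4.0.
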
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
    """Unswizzle FP4 block scale factors from the tiled kernel layout back to
    a linear row-major view."""
    m_tiles = (m + 128 - 1) // 128
    f = block_size * 4
    k_tiles = (k + f - 1) // f
    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
    out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
    return out[0:m, 0:k]
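
# Shape sketch (illustrative numbers, not taken from any kernel): for m=200,
# k=256, block_size=16 the swizzled input covers ceil(200/128)*128 = 256 rows
# and ceil(256/64)*64 // 16 = 16 scale columns; the final slice trims the
# padding back to the valid 200 x 16 region.
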
def dequantize_nvfp4_to_dtype(
    tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
):
    """Dequantize the fp4 tensor back to high precision."""
    # Two fp4 values are packed into one uint8.
    assert tensor_fp4.dtype == torch.uint8
    m, packed_k = tensor_fp4.shape
    k = packed_k * 2
    tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
    tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
    tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
    tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
    tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale

    # Apply the per-block scale factors and flatten back to (m, k).
    out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
    return out.to(dtype=dtype)
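
# Per-element view of the math above: out[i, j] is the decoded fp4 value at
# (i, j) times sf_e4m3[i, j // block_size] / global_scale.
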
def break_fp4_bytes(a, dtype):
    """Unpack two FP4 (E2M1) values per uint8 byte and decode them to dtype."""
    assert a.dtype == torch.uint8
    m, n = a.shape

    # Vectorized nibble processing
    a_flat = a.flatten()
    high = (a_flat & 0xF0) >> 4  # Upper nibbles
    low = a_flat & 0x0F  # Lower nibbles

    # Interleave (low, high) so the lower nibble is the first value of each pair
    combined = torch.stack((low, high), dim=1).flatten()

    # Vectorized sign and magnitude extraction
    signs = (combined & 0x08).to(torch.bool)  # Sign bits
    abs_vals = (combined & 0x07).to(torch.long)  # Magnitude indices

    # Device-aware lookup and sign application
    kE2M1 = kE2M1ToFloat.to(device=a.device)
    values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)

    # Reshape to final form
    return values.reshape(m, n * 2).to(dtype=dtype)
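
# Worked example: byte 0x2E unpacks to low nibble 0xE (sign bit set, magnitude
# index 6 -> -4.0) followed by high nibble 0x2 (index 2 -> 1.0).
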
def get_nvfp4_global_scale(a: torch.Tensor):
    return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(torch.float32)
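
# Rationale, as implied by the constants above: with e4m3 block scales
# (max 448.0) and e2m1 values (max 6.0), the largest representable magnitude
# is 448 * 6 = 2688, so this maps amax(a) onto exactly that ceiling.
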
def quant_nvfp4_tensor(a: torch.Tensor):
    a_global_scale = get_nvfp4_global_scale(a)
    a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale)
    return a_quant, a_block_scale, a_global_scale
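
# Round-trip sketch (assumes a build where the scaled_fp4_quant custom op is
# available, e.g. on supported CUDA hardware; names and shapes are
# illustrative):
#   a = torch.randn(128, 256, dtype=torch.half, device="cuda")
#   a_quant, a_sf, a_gs = quant_nvfp4_tensor(a)
#   a_dq = dequantize_nvfp4_to_dtype(a_quant, a_sf, a_gs, a.dtype, a.device)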