mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 23:06:10 +08:00
[Misc] Add compressed-tensors NVFP4A16 emulation support (#17914)
Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com> Signed-off-by: Dipika <dipikasikka1@gmail.com>
This commit is contained in:
parent
9cea90eab4
commit
cd3edfc908
@ -13,9 +13,9 @@ from compressed_tensors.quantization import QuantizationType
|
||||
from tests.models.utils import check_logprobs_close
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensors24, CompressedTensorsLinearMethod,
|
||||
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
|
||||
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
|
||||
CompressedTensorsWNA16)
|
||||
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
|
||||
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
sparse_cutlass_supported)
|
||||
from vllm.platforms import current_platform
|
||||
@ -648,3 +648,23 @@ def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4):
|
||||
output = llm.generate_greedy("Hello my name is", max_tokens=20)
|
||||
print(output)
|
||||
assert output
|
||||
|
||||
|
||||
def test_compressed_tensors_nvfp4a16(vllm_runner):
|
||||
# run weight only example
|
||||
model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP4"
|
||||
with vllm_runner(model, enforce_eager=True) as llm:
|
||||
|
||||
def check_model(model):
|
||||
layer = model.model.layers[0]
|
||||
|
||||
qkv_proj = layer.self_attn.qkv_proj
|
||||
assert isinstance(qkv_proj.quant_method,
|
||||
CompressedTensorsLinearMethod)
|
||||
assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Fp4)
|
||||
assert qkv_proj.scheme.group_size == 16
|
||||
|
||||
llm.apply_model(check_model)
|
||||
output = llm.generate_greedy("Hello my name is", max_tokens=20)
|
||||
print(output)
|
||||
assert output
|
||||
|
||||
@ -23,9 +23,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
|
||||
CompressedTensorsMoEMethod)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
|
||||
CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
|
||||
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
|
||||
CompressedTensorsScheme, CompressedTensorsW4A16Fp4,
|
||||
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
|
||||
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
|
||||
CompressedTensorsWNA16)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
find_matched_target, is_activation_quantization_format,
|
||||
should_ignore_layer)
|
||||
@ -216,6 +217,21 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
else:
|
||||
return False
|
||||
|
||||
def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel):
|
||||
|
||||
is_weight_only = weight_quant is not None and input_quant is None
|
||||
is_group_quant = (
|
||||
weight_quant.strategy == QuantizationStrategy.GROUP.value)
|
||||
is_symmetric = weight_quant.symmetric
|
||||
|
||||
is_group_size_16 = weight_quant.group_size == 16
|
||||
is_float_type = weight_quant.type == QuantizationType.FLOAT
|
||||
is_4_bits = weight_quant.num_bits == 4
|
||||
|
||||
return (is_weight_only and is_group_quant and is_float_type
|
||||
and is_4_bits and is_group_size_16 and is_symmetric)
|
||||
|
||||
def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> bool:
|
||||
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
|
||||
@ -315,6 +331,9 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
input_quant: BaseModel) -> "CompressedTensorsScheme":
|
||||
|
||||
# Detect If Mixed Precision
|
||||
if self._is_fp4a16_nvfp4(weight_quant, input_quant):
|
||||
return CompressedTensorsW4A16Fp4()
|
||||
|
||||
if self._is_wNa16_group_channel(weight_quant, input_quant):
|
||||
if (self.quant_format == CompressionFormat.marlin_24.value
|
||||
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
from .compressed_tensors_scheme import CompressedTensorsScheme
|
||||
from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
|
||||
CompressedTensorsW4A16Sparse24)
|
||||
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
|
||||
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
||||
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
|
||||
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
|
||||
@ -16,5 +17,5 @@ __all__ = [
|
||||
"CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24",
|
||||
"CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
|
||||
"WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
|
||||
"CompressedTensors24"
|
||||
"CompressedTensors24", "CompressedTensorsW4A16Fp4"
|
||||
]
|
||||
|
||||
@ -0,0 +1,107 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
|
||||
dequantize_to_dtype)
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
|
||||
__all__ = ["CompressedTensorsW4A16Fp4"]
|
||||
|
||||
|
||||
class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
|
||||
|
||||
def __init__(self):
|
||||
self.group_size = 16
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# dont restrict as emulations
|
||||
return 80
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: List[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
# Weight
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // 2,
|
||||
dtype=torch.uint8),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_packed", weight)
|
||||
|
||||
# Global Weight Scale
|
||||
weight_global_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_global_scale", weight_global_scale)
|
||||
|
||||
# Per Group Weight Scale
|
||||
weight_scale = GroupQuantScaleParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // self.group_size,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def swizzle_blockscale(self, scale: torch.tensor):
|
||||
assert (scale.dtype == torch.float8_e4m3fn)
|
||||
# Pad and blockwise interleave weight_scale
|
||||
scale_ndim = scale.ndim
|
||||
if scale.ndim == 2:
|
||||
scale = scale.unsqueeze(0)
|
||||
assert scale.ndim == 3
|
||||
B, M, K = scale.shape
|
||||
round_up_multiple = lambda x, m: (x + m - 1) // m * m
|
||||
M_padded = round_up_multiple(M, 128)
|
||||
K_padded = round_up_multiple(K, 4)
|
||||
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
|
||||
padded_scale[:B, :M, :K] = scale
|
||||
batches, rows, cols = padded_scale.shape
|
||||
assert rows % 128 == 0
|
||||
assert cols % 4 == 0
|
||||
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
|
||||
cols // 4, 4)
|
||||
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
|
||||
swizzled_scale = swizzled_scale.contiguous().cuda()
|
||||
return (swizzled_scale.reshape(M, K)
|
||||
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
layer.weight_global_scale = Parameter(
|
||||
layer.weight_global_scale.max().to(torch.float32),
|
||||
requires_grad=False)
|
||||
# Note: a post weight loading step but not required for the emulation
|
||||
swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
|
||||
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
|
||||
requires_grad=False)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
w_fp4 = layer.weight_packed.data
|
||||
w_global_scale = layer.weight_global_scale
|
||||
w_blockscale = layer.weight_scale_swizzled.data
|
||||
w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
|
||||
x.dtype, x.device, self.group_size)
|
||||
out = F.linear(x, w_dq)
|
||||
del w_dq, w_fp4, w_global_scale, w_blockscale
|
||||
return out
|
||||
@ -0,0 +1,61 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import torch
|
||||
|
||||
__all__ = [
|
||||
"break_fp4_bytes",
|
||||
"dequantize_to_dtype",
|
||||
]
|
||||
|
||||
kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
|
||||
dtype=torch.float32)
|
||||
|
||||
|
||||
def break_fp4_bytes(a, dtype):
|
||||
assert a.dtype == torch.uint8
|
||||
m, n = a.shape
|
||||
# Vectorized nibble processing
|
||||
a_flat = a.flatten()
|
||||
high = (a_flat & 0xF0) >> 4 # Upper nibbles
|
||||
low = a_flat & 0x0F # Lower nibbles
|
||||
# Combine nibbles for batch processing
|
||||
combined = torch.stack((low, high), dim=1).flatten()
|
||||
# Vectorized sign and magnitude extraction
|
||||
signs = (combined & 0x08).to(torch.bool) # Sign bits
|
||||
abs_vals = (combined & 0x07).to(torch.long)
|
||||
# Device-aware lookup and sign application
|
||||
kE2M1 = kE2M1ToFloat.to(device=a.device)
|
||||
values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
|
||||
# Reshape to final form
|
||||
return values.reshape(m, n * 2).to(dtype=dtype)
|
||||
|
||||
|
||||
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
|
||||
m_tiles = (m + 128 - 1) // 128
|
||||
f = block_size * 4
|
||||
k_tiles = (k + f - 1) // f
|
||||
tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
|
||||
tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
|
||||
out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
|
||||
return out[0:m, 0:k]
|
||||
|
||||
|
||||
def dequantize_to_dtype(tensor_fp4,
|
||||
tensor_sf,
|
||||
global_scale,
|
||||
dtype,
|
||||
device,
|
||||
block_size=16):
|
||||
"""Dequantize the fp4 tensor back to high precision."""
|
||||
# Two fp4 values are packed into one uint8.
|
||||
assert tensor_fp4.dtype == torch.uint8
|
||||
m, packed_k = tensor_fp4.shape
|
||||
k = packed_k * 2
|
||||
tensor_f32 = break_fp4_bytes(tensor_fp4, torch.float32)
|
||||
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
|
||||
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
|
||||
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
|
||||
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
|
||||
|
||||
# scale the tensor
|
||||
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
|
||||
return out.to(dtype)
|
||||
Loading…
x
Reference in New Issue
Block a user