Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 12:05:48 +08:00)
[Quantization] Remove FP4 emulation; Fall-back to marlin for device < 100 (#19563)
commit 6bc7b57315
parent 90f9c2eb5c
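The gist of the change: the W4A4 NVFP4 scheme now requires the cutlass FP4 kernel, which only exists on devices with compute capability 100 (SM 10.0, Blackwell). On older GPUs the config falls back to the weight-only W4A16 FP4 scheme, executed by the marlin kernel, instead of emulating FP4 in higher precision. A minimal sketch of that gate, using plain torch capability queries rather than vLLM's current_platform / cutlass_scaled_mm_supports_fp4 helpers (illustrative only, not the library's code):

import torch

def pick_fp4_scheme() -> str:
    """Illustrative stand-in for the scheme selection added in this commit."""
    if not torch.cuda.is_available():
        # No CUDA device: the cutlass NVFP4 path is unavailable.
        return "CompressedTensorsW4A16Fp4"
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor  # e.g. SM 10.0 -> 100, SM 8.0 -> 80
    if capability >= 100:
        return "CompressedTensorsW4A4Fp4"   # cutlass FP4 GEMM path
    # device < 100: weight-only FP4, repacked for the marlin kernel
    return "CompressedTensorsW4A16Fp4"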
@@ -667,7 +667,13 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
             qkv_proj = layer.self_attn.qkv_proj
             assert isinstance(qkv_proj.quant_method,
                               CompressedTensorsLinearMethod)
-            assert isinstance(qkv_proj.scheme, scheme)
+            if isinstance(qkv_proj.scheme, scheme) or isinstance(
+                    qkv_proj.scheme, CompressedTensorsW4A16Fp4
+            ) and not CompressedTensorsW4A4Fp4.cutlass_fp4_supported():
+                assert True
+            else:
+                raise AssertionError("FP4 Scheme Mismatch")
+
             assert qkv_proj.scheme.group_size == 16

         llm.apply_model(check_model)
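One thing to note about the test change above: Python's `and` binds tighter than `or`, so the new condition reads as "scheme matches, or (it is the W4A16 fallback and cutlass FP4 is unsupported)". A tiny self-contained check of that precedence rule (not part of the commit):

# "A or B and C" parses as "A or (B and C)", never "(A or B) and C".
for A in (False, True):
    for B in (False, True):
        for C in (False, True):
            assert (A or B and C) == (A or (B and C))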
@@ -374,7 +374,14 @@ class CompressedTensorsConfig(QuantizationConfig):

         if is_activation_quantization_format(self.quant_format):
             if self._is_fp4a4_nvfp4(weight_quant, input_quant):
-                return CompressedTensorsW4A4Fp4()
+                if CompressedTensorsW4A4Fp4.cutlass_fp4_supported():
+                    return CompressedTensorsW4A4Fp4()
+                else:
+                    logger.warning_once(
+                        "Current platform does not support cutlass NVFP4."
+                        " Running CompressedTensorsW4A16Fp4.")
+                    return CompressedTensorsW4A16Fp4(
+                        has_input_global_scale=True)

             if self._is_fp8_w8a8(weight_quant, input_quant):
                 is_fp8_w8a8_supported = self._check_scheme_supported(
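In practice, the branch above means an NVFP4 compressed-tensors checkpoint that previously fell back to emulation on pre-Blackwell GPUs now loads with the weight-only scheme. A hypothetical smoke test of that behavior (the model path is a placeholder, not taken from this commit):

from vllm import LLM

# On a GPU with compute capability < 100, startup is expected to log
# "Current platform does not support cutlass NVFP4. Running CompressedTensorsW4A16Fp4."
# and the linear layers get a CompressedTensorsW4A16Fp4 scheme backed by marlin.
llm = LLM(model="/path/to/nvfp4-compressed-tensors-checkpoint",
          enforce_eager=True)
print(llm.generate("Hello")[0].outputs[0].text)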
@@ -18,7 +18,8 @@ __all__ = ["CompressedTensorsW4A16Fp4"]

 class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):

-    def __init__(self):
+    def __init__(self, has_input_global_scale: bool = False):
+        self.has_input_global_scale = has_input_global_scale
         self.group_size = 16

     @classmethod
@@ -64,6 +65,13 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):

         layer.register_parameter("weight_scale", weight_scale)

+        if self.has_input_global_scale:
+            input_global_scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes),
+                                 dtype=torch.float32),
+                weight_loader=weight_loader)
+            layer.register_parameter("input_global_scale", input_global_scale)
+
     def process_weights_after_loading(self, layer) -> None:
         # Process parameters for marlin repacking

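The extra `input_global_scale` parameter exists so that NVFP4 checkpoints, which ship that tensor, still have somewhere to load it even when the weight-only scheme is selected. vLLM loads weights through its own weight_loader machinery rather than `load_state_dict`, but the underlying point is the same, as this plain-torch sketch (with a made-up scale value) illustrates:

import torch

layer = torch.nn.Module()
layer.register_parameter(
    "input_global_scale",
    torch.nn.Parameter(torch.empty(1, dtype=torch.float32),
                       requires_grad=False))

# If the destination parameter were never registered, this checkpoint tensor
# would have no home and loading would fail with an unexpected-key error.
layer.load_state_dict({"input_global_scale": torch.tensor([448.0])})
print(layer.input_global_scale)  # tensor([448.])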
@@ -77,6 +85,10 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
                                              requires_grad=False)
         del layer.weight_global_scale

+        if self.has_input_global_scale:
+            layer.input_global_scale = torch.nn.Parameter(
+                layer.input_global_scale.data, requires_grad=False)
+
         prepare_fp4_layer_for_marlin(layer)

     def apply_weights(self,
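The conversion to a plain `torch.nn.Parameter` before `prepare_fp4_layer_for_marlin` follows the usual vLLM pattern: once loading is done, the vLLM parameter subclasses (here `PerTensorScaleParameter`) are swapped for ordinary parameters so downstream kernels see plain tensors. A small illustration of the idea (the subclass name below is a stand-in, not a vLLM class):

import torch

class LoaderAwareParameter(torch.nn.Parameter):
    """Stand-in for vLLM parameter subclasses that carry weight-loader metadata."""

p = LoaderAwareParameter(torch.ones(2), requires_grad=False)
plain = torch.nn.Parameter(p.data, requires_grad=False)
assert type(plain) is torch.nn.Parameter and not plain.requires_grad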
@@ -9,8 +9,6 @@ from vllm._custom_ops import (cutlass_scaled_fp4_mm,
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
-    dequantize_to_dtype, ref_nvfp4_quant)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -21,53 +19,23 @@ logger = init_logger(__name__)
 __all__ = ["CompressedTensorsW4A4Fp4"]


-def cutlass_fp4_supported() -> bool:
-    if not current_platform.is_cuda():
-        return False
-    capability_tuple = current_platform.get_device_capability()
-    capability = -1 if capability_tuple is None else capability_tuple.to_int()
-    return cutlass_scaled_mm_supports_fp4(capability)
-
-
 class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):

     def __init__(self):
         self.group_size = 16
-        self.cutlass_nvfp4_supported = cutlass_fp4_supported()
-        if not self.cutlass_nvfp4_supported:
-            logger.warning("Current platform does not support cutlass NVFP4."
-                           " Running emulations.")

     @classmethod
     def get_min_capability(cls) -> int:
-        # dont restrict as emulations
-        return 80
+        return 100

-    def run_nvfp4_emulations(self, x: torch.Tensor, layer):
-        x_m, x_k = x.shape
-        output_dtype = x.dtype
-
-        # quantize input to (FP4 and interleaved block scale)
-        x_fp4, x_blockscale = ref_nvfp4_quant(x, layer.input_global_scale,
-                                              self.group_size)
-
-        # dequantize input
-        x_fp4 = x_fp4.reshape(x_m, x_k // self.group_size, self.group_size)
-        x_blockscale = x_blockscale.unsqueeze(-1) / layer.input_global_scale
-        x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
-        del x_fp4, x_blockscale
-
-        # dequantize weight
-        w_fp4 = layer.weight.data.view(torch.uint8)
-        w_blockscale = layer.weight_scale_swizzled.data
-        w_global_scale = layer.weight_global_scale
-        w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
-                                   output_dtype, x.device, self.group_size)
-
-        # matmul
-        out = torch.matmul(x_dq, w_dq.t())
-        del w_dq, x_dq
-        return out
+    @classmethod
+    def cutlass_fp4_supported(cls) -> bool:
+        if not current_platform.is_cuda():
+            return False
+        capability_tuple = current_platform.get_device_capability()
+        capability = -1 if capability_tuple is None else capability_tuple.to_int(  # noqa: E501
+        )
+        return cutlass_scaled_mm_supports_fp4(capability)

     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: list[int],
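For context on the new `get_min_capability` value: vLLM encodes device capability as `major * 10 + minor` (that is what `capability_tuple.to_int()` returns), so 100 corresponds to SM 10.0 (Blackwell), while an A100 reports 80 and an H100 reports 90. A quick worked check of that encoding (the GPU values are common knowledge, not taken from this diff):

def capability_to_int(major: int, minor: int) -> int:
    # Same encoding as the capability integer used above: major * 10 + minor.
    return major * 10 + minor

assert capability_to_int(8, 0) == 80     # A100
assert capability_to_int(9, 0) == 90     # H100
assert capability_to_int(10, 0) == 100   # SM 10.0, the new minimum for W4A4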
@@ -152,27 +120,24 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
         # required by cutlass kernel; need Parameter, not ModelWeightParameter
         layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)

-        if self.cutlass_nvfp4_supported:
-            layer.alpha = Parameter(layer.input_global_scale *
-                                    layer.weight_global_scale,
-                                    requires_grad=False)
+        layer.alpha = Parameter(layer.input_global_scale *
+                                layer.weight_global_scale,
+                                requires_grad=False)

     def apply_weights(self,
                       layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:

-        if self.cutlass_nvfp4_supported:
-            output_dtype = x.dtype
-            output_shape = [x.shape[0], layer.weight.shape[0]]
-
-            # quantize BF16 or FP16 to (FP4 and interleaved block scale)
-            x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
-
-            out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
-                                        layer.weight_scale_swizzled,
-                                        1 / layer.alpha, output_dtype)
-            if bias is not None:
-                out = out + bias
-            return out.view(*output_shape)
-        return self.run_nvfp4_emulations(x, layer)
+        output_dtype = x.dtype
+        output_shape = [x.shape[0], layer.weight.shape[0]]
+
+        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
+        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
+
+        out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
+                                    layer.weight_scale_swizzled,
+                                    1 / layer.alpha, output_dtype)
+        if bias is not None:
+            out = out + bias
+        return out.view(*output_shape)
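The now-unconditional cutlass path passes `1 / layer.alpha` with `alpha = input_global_scale * weight_global_scale`: the global scales applied while quantizing the activations and weights have to be divided back out of the GEMM result (the per-block scales are handled by the kernel itself). A per-tensor-only sketch of that scale algebra, with made-up scale values:

import torch

x_gs, w_gs = torch.tensor(4.0), torch.tensor(2.0)  # made-up global scales
x, w = torch.randn(8, 16), torch.randn(32, 16)

x_q = x * x_gs            # stand-in for the globally scaled activation
w_q = w * w_gs            # stand-in for the globally scaled weight

alpha = x_gs * w_gs
out = (x_q @ w_q.t()) * (1.0 / alpha)  # divide both global scales back out
assert torch.allclose(out, x @ w.t(), atol=1e-4)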
@@ -102,3 +102,32 @@ def ref_nvfp4_quant(x, global_scale, block_size):
     clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
     # both outputs are float32
     return cast_to_fp4(clipped_x), scale.squeeze(-1)
+
+
+def run_nvfp4_emulations(x: torch.Tensor, input_global_scale: torch.Tensor,
+                         weight: torch.Tensor,
+                         weight_scale_swizzled: torch.Tensor,
+                         weight_global_scale: torch.Tensor):
+    group_size = 16
+    x_m, x_k = x.shape
+    output_dtype = x.dtype
+
+    # quantize input to (FP4 and interleaved block scale)
+    x_fp4, x_blockscale = ref_nvfp4_quant(x, input_global_scale, group_size)
+
+    # dequantize input
+    x_fp4 = x_fp4.reshape(x_m, x_k // group_size, group_size)
+    x_blockscale = x_blockscale.unsqueeze(-1) / input_global_scale
+    x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
+    del x_fp4, x_blockscale
+
+    # dequantize weight
+    w_fp4 = weight.data.view(torch.uint8)
+    w_dq = dequantize_to_dtype(w_fp4, weight_scale_swizzled.data,
+                               weight_global_scale, output_dtype, x.device,
+                               group_size)
+
+    # matmul
+    out = torch.matmul(x_dq, w_dq.t())
+    del w_dq, x_dq
+    return out