diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index cd0513652097..5bf96398bc71 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -805,12 +805,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): assert loaded_shard_id < len(self.output_sizes) if isinstance(param, BlockQuantScaleParameter): - from vllm.model_executor.layers.quantization.fp8 import ( - Fp8LinearMethod, Fp8MoEMethod) assert self.quant_method is not None - assert isinstance(self.quant_method, - (Fp8LinearMethod, Fp8MoEMethod)) - weight_block_size = self.quant_method.quant_config.weight_block_size + # Assume the weight block size has been set by quant method + assert hasattr(self, "weight_block_size") + weight_block_size = self.weight_block_size assert weight_block_size is not None block_n, _ = weight_block_size[0], weight_block_size[1] shard_offset = ( @@ -989,8 +987,10 @@ class QKVParallelLinear(ColumnParallelLinear): # Note(simon): This is needed for Qwen3's fp8 quantization. if isinstance(param, BlockQuantScaleParameter): assert self.quant_method is not None - assert hasattr(self.quant_method, "quant_config") - weight_block_size = self.quant_method.quant_config.weight_block_size + # Assume the weight block size has been set by quant method + assert hasattr(self, "weight_block_size") + weight_block_size = self.weight_block_size + assert weight_block_size is not None block_n, _ = weight_block_size[0], weight_block_size[1] shard_offset = (shard_offset + block_n - 1) // block_n shard_size = (shard_size + block_n - 1) // block_n diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index b56a69131177..d6550dd16892 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -12,7 +12,6 @@ from compressed_tensors.quantization import (QuantizationArgs, QuantizationStrategy, QuantizationType) from compressed_tensors.transform import TransformConfig -from pydantic import BaseModel import vllm.envs as envs from vllm.logger import init_logger @@ -268,7 +267,8 @@ class CompressedTensorsConfig(QuantizationConfig): else: return False - def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel): + def _is_fp4a4_nvfp4(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs): if weight_quant is None or input_quant is None: return False @@ -288,8 +288,8 @@ class CompressedTensorsConfig(QuantizationConfig): return (is_tensor_group_quant and is_float_type and is_4_bits and is_group_size_16 and is_symmetric) - def _is_fp4a16_nvfp4(self, weight_quant: BaseModel, - input_quant: BaseModel): + def _is_fp4a16_nvfp4(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs): is_weight_only = weight_quant is not None and input_quant is None is_tensor_group_quant = ( @@ -303,8 +303,8 @@ class CompressedTensorsConfig(QuantizationConfig): return (is_weight_only and is_tensor_group_quant and is_float_type and is_4_bits and is_group_size_16 and is_symmetric) - def _is_static_tensor_w8a8(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_static_tensor_w8a8(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 weight_strategy = ( weight_quant.strategy == QuantizationStrategy.TENSOR.value @@ -317,8 +317,8 @@ class CompressedTensorsConfig(QuantizationConfig): # Only symmetric weight quantization supported. return is_8_bits and is_tensor and weight_quant.symmetric and is_static - def _is_dynamic_token_w8a8(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 weight_strategy = ( weight_quant.strategy == QuantizationStrategy.TENSOR.value @@ -331,8 +331,8 @@ class CompressedTensorsConfig(QuantizationConfig): # Only symmetric weight quantization supported. return is_8_bits and is_token and weight_quant.symmetric and is_dynamic - def _is_dynamic_token_w4a8_int(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_dynamic_token_w4a8_int(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: is_weight_4_bits = weight_quant.num_bits == 4 is_activation_8_bits = input_quant.num_bits == 8 weight_strategy = ( @@ -347,8 +347,8 @@ class CompressedTensorsConfig(QuantizationConfig): return (is_weight_4_bits and is_activation_8_bits and is_token and weight_quant.symmetric and is_dynamic) - def _is_fp8_w8a8(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_fp8_w8a8(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: # Confirm weights and activations quantized. if weight_quant is None or input_quant is None: return False @@ -358,11 +358,12 @@ class CompressedTensorsConfig(QuantizationConfig): and input_quant.type == QuantizationType.FLOAT) is_symmetric_weight = weight_quant.symmetric is_static_weight = not weight_quant.dynamic - is_per_tensor_or_channel_weight = (weight_quant.strategy in [ - QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL + is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [ + QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL, + QuantizationStrategy.BLOCK ]) if not (is_floating_point and is_symmetric_weight and is_static_weight - and is_per_tensor_or_channel_weight): + and is_tensor_or_channel_or_block_weight): return False # Dynamic quantization is always supported if weights supported. @@ -375,8 +376,8 @@ class CompressedTensorsConfig(QuantizationConfig): input_quant.strategy == QuantizationStrategy.TENSOR) return is_symmetric_activation and is_per_tensor_activation - def _is_fp8_w4a8(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_fp8_w4a8(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: if not weight_quant or not input_quant: return False is_weight_4_bits = weight_quant.num_bits == 4 @@ -392,24 +393,24 @@ class CompressedTensorsConfig(QuantizationConfig): return (is_weight_4_bits and is_activation_8_bits and is_token and is_symmetric and is_dynamic) - def _is_fp8_w4a8_sm90(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_fp8_w4a8_sm90(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: return (self._check_scheme_supported(90, error=False, match_exact=True) and self._is_fp8_w4a8(weight_quant, input_quant)) - def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_fp8_w8a8_sm90(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: return (self._check_scheme_supported(90, error=False, match_exact=True) and self._is_fp8_w8a8(weight_quant, input_quant)) - def _is_fp8_w8a8_sm100(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_fp8_w8a8_sm100(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: return (self._check_scheme_supported( 100, error=False, match_exact=True) and self._is_fp8_w8a8(weight_quant, input_quant)) - def _is_fp8_w8a16(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_fp8_w8a16(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: # Confirm weights quantized. if weight_quant is None: return False @@ -421,18 +422,19 @@ class CompressedTensorsConfig(QuantizationConfig): # Confirm weight scheme is supported. is_symmetric_weight = weight_quant.symmetric is_static_weight = not weight_quant.dynamic - is_per_tensor_or_channel_weight = (weight_quant.strategy in [ - QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL + is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [ + QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL, + QuantizationStrategy.BLOCK ]) if not (is_symmetric_weight and is_static_weight # noqa: SIM103 - and is_per_tensor_or_channel_weight): + and is_tensor_or_channel_or_block_weight): return False # All conditions satisfied. return True - def _is_wNa16_group_channel(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_wNa16_group_channel(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: input_quant_none = input_quant is None is_channel_group = ( weight_quant.strategy == QuantizationStrategy.CHANNEL.value @@ -443,8 +445,8 @@ class CompressedTensorsConfig(QuantizationConfig): def _get_scheme_from_parts( self, - weight_quant: BaseModel, - input_quant: BaseModel, + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs, format: Optional[str] = None) -> "CompressedTensorsScheme": # use the per-layer format if defined, otherwise, use global format @@ -496,7 +498,7 @@ class CompressedTensorsConfig(QuantizationConfig): CompressedTensorsW8A8Fp8.get_min_capability(), error=False) if is_fp8_w8a8_supported: return CompressedTensorsW8A8Fp8( - strategy=weight_quant.strategy, + weight_quant=weight_quant, is_static_input_scheme=(input_quant and not input_quant.dynamic)) else: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index d984e89d9e02..d42ae22c5139 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -4,28 +4,41 @@ from typing import Callable, Optional import torch -from compressed_tensors.quantization import QuantizationStrategy +from compressed_tensors.quantization import (QuantizationArgs, + QuantizationStrategy) from torch.nn import Parameter from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + apply_fp8_block_linear, check_aiter_fp8_linear_support, + create_fp8_input_scale, create_fp8_scale_parameter, + create_fp8_weight_parameter, maybe_post_process_fp8_weight_block, + process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy, + process_fp8_weight_tensor_strategy, validate_fp8_block_shape) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, - requantize_with_max_scale) -from vllm.model_executor.parameter import (ChannelQuantScaleParameter, - ModelWeightParameter, + Fp8LinearOp, cutlass_block_fp8_supported, maybe_create_device_identity) +from vllm.model_executor.parameter import (BlockQuantScaleParameter, + ChannelQuantScaleParameter, PerTensorScaleParameter) -from vllm.platforms import current_platform __all__ = ["CompressedTensorsW8A8Fp8"] +strategy_to_parameter_type = { + QuantizationStrategy.BLOCK: BlockQuantScaleParameter, + QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter, + QuantizationStrategy.TENSOR: PerTensorScaleParameter, +} + class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): - def __init__(self, strategy: str, is_static_input_scheme: bool): - self.strategy = strategy + def __init__(self, weight_quant: QuantizationArgs, + is_static_input_scheme: bool): + self.weight_quant = weight_quant + self.strategy = weight_quant.strategy self.out_dtype = torch.get_default_dtype() self.is_static_input_scheme = is_static_input_scheme self.act_q_group_shape = GroupShape.PER_TENSOR \ @@ -34,61 +47,84 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): act_quant_static=self.is_static_input_scheme, act_quant_group_shape=self.act_q_group_shape) + self.weight_block_size = self.weight_quant.block_structure + self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() + self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() + @classmethod def get_min_capability(cls) -> int: # lovelace and up return 89 + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + weight_loader: Callable, **kwargs): + maybe_create_device_identity() + + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + layer.weight_block_size = None + + if self.strategy == QuantizationStrategy.BLOCK: + assert self.weight_block_size is not None + layer.weight_block_size = self.weight_block_size + # Validate block quantization shapes + validate_fp8_block_shape(layer, input_size, output_size, + input_size_per_partition, + output_partition_sizes, + self.weight_block_size) + + # WEIGHT + weight = create_fp8_weight_parameter(output_size_per_partition, + input_size_per_partition, + weight_loader) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + weight_scale = create_fp8_scale_parameter( + strategy_to_parameter_type[self.strategy], output_partition_sizes, + input_size_per_partition, layer.weight_block_size, weight_loader) + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE + if self.is_static_input_scheme: + input_scale = create_fp8_input_scale(output_partition_sizes, + weight_loader) + layer.register_parameter("input_scale", input_scale) + def process_weights_after_loading(self, layer) -> None: - # If per tensor, when we have a fused module (e.g. QKV) with per - # tensor scales (thus N scales being passed to the kernel), - # requantize so we can always run per tensor if self.strategy == QuantizationStrategy.TENSOR: - max_w_scale, weight = requantize_with_max_scale( - weight=layer.weight, - weight_scale=layer.weight_scale, - logical_widths=layer.logical_widths, - ) + weight, weight_scale, input_scale = ( + process_fp8_weight_tensor_strategy( + layer.weight, layer.weight_scale, layer.logical_widths, + getattr(layer, 'input_scale', None))) + weight = weight.t() - if current_platform.is_fp8_fnuz(): - input_scale = getattr(layer, 'input_scale', None) - - weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( - weight=weight, - weight_scale=max_w_scale, - input_scale=input_scale) - if input_scale is not None: - layer.input_scale = Parameter(input_scale, - requires_grad=False) - - layer.weight = Parameter(weight.t(), requires_grad=False) - layer.weight_scale = Parameter(max_w_scale, requires_grad=False) - - # If channelwise, scales are already lined up, so just transpose. elif self.strategy == QuantizationStrategy.CHANNEL: - weight = layer.weight + weight, weight_scale, input_scale = ( + process_fp8_weight_channel_strategy( + layer.weight, layer.weight_scale, + getattr(layer, 'input_scale', None))) + weight = weight.t() - if current_platform.is_fp8_fnuz(): - input_scale = getattr(layer, 'input_scale', None) - - weight, weight_scale, input_scale = \ - normalize_e4m3fn_to_e4m3fnuz( - weight=weight, - weight_scale=layer.weight_scale, - input_scale=input_scale) - if input_scale is not None: - layer.input_scale = Parameter(input_scale, - requires_grad=False) - else: - weight_scale = layer.weight_scale.data - - layer.weight = Parameter(weight.t(), requires_grad=False) - # required by torch.compile to be torch.nn.Parameter - layer.weight_scale = Parameter(weight_scale, requires_grad=False) + elif self.strategy == QuantizationStrategy.BLOCK: + assert self.is_static_input_scheme is False + weight, weight_scale = process_fp8_weight_block_strategy( + layer.weight, layer.weight_scale) + input_scale = None else: raise ValueError(f"Unknown quantization strategy {self.strategy}") + # required by torch.compile to be torch.nn.Parameter + layer.weight = Parameter(weight.data, requires_grad=False) + layer.weight_scale = Parameter(weight_scale.data, requires_grad=False) + if input_scale is not None: + layer.input_scale = Parameter(input_scale.data, + requires_grad=False) + # INPUT SCALE if self.is_static_input_scheme and hasattr(layer, 'input_scale'): layer.input_scale = Parameter(layer.input_scale.max(), @@ -96,58 +132,23 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): else: layer.input_scale = None - def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): - maybe_create_device_identity() - - output_size_per_partition = sum(output_partition_sizes) - layer.logical_widths = output_partition_sizes - - # WEIGHT - weight = ModelWeightParameter(data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=torch.float8_e4m3fn), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) - layer.register_parameter("weight", weight) - - # WEIGHT SCALE - # TODO: update create_xxx_parameter functions to return - # the newly added parameters - if self.strategy == QuantizationStrategy.CHANNEL: - weight_scale = ChannelQuantScaleParameter( - data=torch.empty((sum(output_partition_sizes), 1), - dtype=torch.float32), - output_dim=0, - weight_loader=weight_loader) - else: - assert self.strategy == QuantizationStrategy.TENSOR - weight_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) - - # min requirement for fp8 kernels - weight_scale[:] = torch.finfo(torch.float32).min - layer.register_parameter("weight_scale", weight_scale) - - # INPUT SCALE - if self.is_static_input_scheme: - input_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) - input_scale[:] = torch.finfo(torch.float32).min - layer.register_parameter("input_scale", input_scale) + if self.strategy == QuantizationStrategy.BLOCK: + maybe_post_process_fp8_weight_block( + layer, self.cutlass_block_fp8_supported) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + if layer.weight_block_size is not None: + return apply_fp8_block_linear( + layer, + input=x, + bias=bias, + cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + use_aiter_and_is_supported=self.use_aiter_and_is_supported) + return self.fp8_linear.apply(input=x, weight=layer.weight, weight_scale=layer.weight_scale, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index e75094c54743..aec9c79f1ea8 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch -import torch.nn.functional as F from torch.nn import Module from torch.nn.parameter import Parameter @@ -32,8 +31,12 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace, - should_use_deepgemm_for_fp8_linear) + apply_fp8_block_linear, check_aiter_fp8_linear_support, + create_fp8_input_scale, create_fp8_scale_parameter, + create_fp8_weight_parameter, get_col_major_tma_aligned_tensor, + maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy, + process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace, + validate_fp8_block_shape) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, prepare_moe_fp8_layer_for_marlin) @@ -42,8 +45,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported, cutlass_fp8_supported, maybe_create_device_identity, - normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize, - requantize_with_max_scale) + normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.parameter import (BlockQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) @@ -233,14 +235,10 @@ class Fp8LinearMethod(LinearMethodBase): if current_platform.is_rocm(): self.use_marlin = False - # AITER is only supported on ROCm and only for FP8_FNUZ - # and at the moment are MI300 series - self.use_aiter_and_is_supported = (current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER - and envs.VLLM_ROCM_USE_AITER_LINEAR - and current_platform.is_fp8_fnuz()) + self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() - self.block_quant = self.quant_config.weight_block_size is not None + self.weight_block_size = self.quant_config.weight_block_size + self.block_quant = self.weight_block_size is not None self.act_q_static = self.quant_config.activation_scheme == "static" # Use per-token quantization for better perf if dynamic and cutlass if not self.act_q_static and cutlass_fp8_supported(): @@ -273,51 +271,27 @@ class Fp8LinearMethod(LinearMethodBase): layer.weight_block_size = None if self.block_quant: - tp_size = getattr(layer, "tp_size", - get_tensor_model_parallel_world_size()) - assert self.quant_config.weight_block_size is not None - layer.weight_block_size = self.quant_config.weight_block_size - block_n, block_k = ( - self.quant_config.weight_block_size[0], - self.quant_config.weight_block_size[1], - ) - # Required by row parallel - if (tp_size > 1 - and input_size // input_size_per_partition == tp_size - and input_size_per_partition % block_k != 0): - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"weight quantization block_k = {block_k}.") - # Required by column parallel or enabling merged weights - is_tp_split = (tp_size > 1 and - output_size // output_size_per_partition == tp_size) - is_merged_gemm = len(output_partition_sizes) > 1 - if is_tp_split or is_merged_gemm: - sizes_to_check = output_partition_sizes - if not is_tp_split and is_merged_gemm: - # In case of merged matrices, we allow the last - # matrix to not be a multiple of block size - sizes_to_check = output_partition_sizes[:-1] - for output_partition_size in sizes_to_check: - if output_partition_size % block_n != 0: - raise ValueError( - f"Weight output_partition_size = " - f"{output_partition_size} is not divisible by " - f"weight quantization block_n = {block_n}.") + assert self.weight_block_size is not None + layer.weight_block_size = self.weight_block_size + validate_fp8_block_shape(layer, input_size, output_size, + input_size_per_partition, + output_partition_sizes, + self.weight_block_size) # WEIGHT - weight_dtype = (torch.float8_e4m3fn - if self.quant_config.is_checkpoint_fp8_serialized else - params_dtype) - - weight = ModelWeightParameter(data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=weight_dtype), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + if self.quant_config.is_checkpoint_fp8_serialized: + weight = create_fp8_weight_parameter(output_size_per_partition, + input_size_per_partition, + weight_loader) + else: + # For non-serialized checkpoints, use original dtype + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=params_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) layer.register_parameter("weight", weight) # If checkpoint is serialized fp8, load them. @@ -325,154 +299,87 @@ class Fp8LinearMethod(LinearMethodBase): if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALE if not self.block_quant: - scale = PerTensorScaleParameter( - data=torch.empty(len(output_partition_sizes), - dtype=torch.float32), - weight_loader=weight_loader, - ) - scale[:] = torch.finfo(torch.float32).min + scale = create_fp8_scale_parameter(PerTensorScaleParameter, + output_partition_sizes, + input_size_per_partition, + None, weight_loader) set_weight_attrs(scale, {"scale_type": "weight_scale"}) layer.register_parameter("weight_scale", scale) else: - assert self.quant_config.activation_scheme == "dynamic" - scale = BlockQuantScaleParameter( - data=torch.empty( - (output_size_per_partition + block_n - 1) // block_n, - (input_size_per_partition + block_k - 1) // block_k, - dtype=torch.float32, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - scale[:] = torch.finfo(torch.float32).min + assert not self.act_q_static + assert self.weight_block_size is not None + scale = create_fp8_scale_parameter(BlockQuantScaleParameter, + output_partition_sizes, + input_size_per_partition, + self.weight_block_size, + weight_loader) set_weight_attrs(scale, {"scale_type": "weight_scale"}) # The weight_scale_inv name is intentional for deepseekv3 layer.register_parameter("weight_scale_inv", scale) # INPUT ACTIVATION SCALE - if self.quant_config.activation_scheme == "static": - scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) - - scale[:] = torch.finfo(torch.float32).min + if self.act_q_static: + scale = create_fp8_input_scale(output_partition_sizes, + weight_loader) set_weight_attrs(scale, {"scale_type": "input_scale"}) layer.register_parameter("input_scale", scale) else: layer.register_parameter("input_scale", None) - def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: - # Pad the weight tensor. This is an optimization on ROCm platform, which - # can benefit from tensors located far enough from one another in memory - if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm() - and weight.stride(-1) == 1 - and (weight.stride(-2) * weight.element_size()) % 512 == 0): - num_pad = 256 // weight.element_size() - weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] - torch.cuda.empty_cache() - return weight - def process_weights_after_loading(self, layer: Module) -> None: size_k_first = True + input_scale = None # TODO(rob): refactor block quant into separate class. if self.block_quant: - assert self.quant_config.activation_scheme == "dynamic" + assert not self.act_q_static size_k_first = False - if current_platform.is_fp8_fnuz(): - weight, weight_scale_inv, _ = \ - normalize_e4m3fn_to_e4m3fnuz( - weight=layer.weight, - weight_scale=layer.weight_scale_inv) - else: - weight = layer.weight.data - weight_scale_inv = layer.weight_scale_inv.data - weight = self._maybe_pad_weight(weight) - - # Torch.compile cannot use Parameter subclasses. - layer.weight = Parameter(weight, requires_grad=False) - layer.weight_scale_inv = Parameter(weight_scale_inv, - requires_grad=False) + weight, weight_scale = process_fp8_weight_block_strategy( + layer.weight, layer.weight_scale_inv) + # Delete the weight_scale_inv parameter to avoid confusion + # with the weight_scale parameter + del layer.weight_scale_inv # If checkpoint not serialized fp8, quantize the weights. elif not self.quant_config.is_checkpoint_fp8_serialized: qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) + weight = qweight.t() - # Update the layer with the new values. - layer.weight = Parameter(qweight.t(), requires_grad=False) - layer.weight_scale = Parameter(weight_scale, requires_grad=False) - # layer.input_scale is None indicates dynamic quant and scale is - # computed from input. - layer.input_scale = None - - # If checkpoint is fp8, handle that there are N scales for N + # If checkpoint is fp8 per-tensor, handle that there are N scales for N # shards in a fused module else: - layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, - requires_grad=False) - if self.quant_config.activation_scheme == "static": - layer.input_scale = torch.nn.Parameter(layer.input_scale.data, - requires_grad=False) - weight = layer.weight weight_scale = layer.weight_scale # If using w8a8, torch._scaled_mm needs per tensor, so # requantize the logical shards as a single weight. if not self.use_marlin: - # Dequant -> Quant with max scale so we can run per tensor. - if current_platform.is_fp8_fnuz(): - weight, weight_scale, input_scale = \ - normalize_e4m3fn_to_e4m3fnuz( - weight=weight, - weight_scale=weight_scale, - input_scale=layer.input_scale) - if input_scale is not None: - layer.input_scale = Parameter(input_scale, - requires_grad=False) + weight, weight_scale, input_scale = ( + process_fp8_weight_tensor_strategy( + weight, weight_scale, layer.logical_widths, + getattr(layer, 'input_scale', None))) + if self.act_q_static: + assert input_scale is not None + input_scale = input_scale.max() + weight = weight.t() - weight_scale, weight = requantize_with_max_scale( - weight=weight, - weight_scale=weight_scale, - logical_widths=layer.logical_widths, - ) - - weight = self._maybe_pad_weight(weight) - # Update layer with new values. - layer.weight = Parameter(weight.t(), requires_grad=False) - layer.weight_scale = Parameter(weight_scale, requires_grad=False) - if self.quant_config.activation_scheme == "static": - layer.input_scale = Parameter(layer.input_scale.max(), - requires_grad=False) + # Update layer with new values. + layer.weight = Parameter(weight.data, requires_grad=False) + layer.weight_scale = Parameter(weight_scale.data, requires_grad=False) + layer.input_scale = Parameter( + input_scale, + requires_grad=False) if input_scale is not None else None if self.use_marlin: prepare_fp8_layer_for_marlin(layer, size_k_first) # Activations not quantized for marlin. del layer.input_scale + return - # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to - # requantize the weight and input to the specific scale - # at the same time. - if is_deep_gemm_e8m0_used() and self.block_quant: - assert layer.weight_block_size is not None - block_sz = tuple(layer.weight_block_size) - requant_weight_ue8m0_inplace( - layer.weight.data, - layer.weight_scale_inv.data if hasattr( - layer, "weight_scale_inv") else layer.weight_scale.data, - block_sz, - ) - - # SM90 Block FP8 CUTLASS requires row-major weight scales - if (self.block_quant and current_platform.is_device_capability(90) - and self.cutlass_block_fp8_supported - and not should_use_deepgemm_for_fp8_linear( - torch.bfloat16, layer.weight)): - layer.weight_scale_inv = Parameter( - layer.weight_scale_inv.data.T.contiguous(), - requires_grad=False) + if self.block_quant: + maybe_post_process_fp8_weight_block( + layer, self.cutlass_block_fp8_supported) def apply(self, layer: torch.nn.Module, @@ -490,18 +397,12 @@ class Fp8LinearMethod(LinearMethodBase): bias=bias) if self.block_quant: - assert self.quant_config.weight_block_size is not None - - return torch.ops.vllm.apply_w8a8_block_fp8_linear( + return apply_fp8_block_linear( + layer, input=x, - weight=layer.weight, - block_size=self.quant_config.weight_block_size, - weight_scale=layer.weight_scale_inv, - input_scale=layer.input_scale, bias=bias, cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, - use_aiter_and_is_supported=self.use_aiter_and_is_supported, - ) + use_aiter_and_is_supported=self.use_aiter_and_is_supported) return self.fp8_linear.apply(input=x, weight=layer.weight, @@ -528,7 +429,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): super().__init__(layer.moe_config) self.layer = layer self.quant_config = quant_config - self.block_quant = self.quant_config.weight_block_size is not None + self.weight_block_size = self.quant_config.weight_block_size + self.block_quant = self.weight_block_size is not None self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None self.fused_experts: Optional[ @@ -590,12 +492,12 @@ class Fp8MoEMethod(FusedMoEMethodBase): if self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn if self.block_quant: - assert self.quant_config.weight_block_size is not None - layer.weight_block_size = self.quant_config.weight_block_size + assert self.weight_block_size is not None + layer.weight_block_size = self.weight_block_size tp_size = get_tensor_model_parallel_world_size() block_n, block_k = ( - self.quant_config.weight_block_size[0], - self.quant_config.weight_block_size[1], + self.weight_block_size[0], + self.weight_block_size[1], ) # NOTE: To ensure proper alignment of the block-wise quantization # scales, the output_size of the weights for both the gate and up @@ -952,7 +854,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): "BatchedTritonOrDeepGemmExperts(%s): " "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s", self.__class__.__name__, max_num_tokens_per_rank, - self.quant_config.weight_block_size, False) + self.weight_block_size, False) return BatchedTritonOrDeepGemmExperts( max_num_tokens=max_num_tokens_per_rank, num_dispatchers=prepare_finalize.num_dispatchers(), @@ -969,8 +871,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): else: logger.debug( "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s", - self.__class__.__name__, self.quant_config.weight_block_size, - False) + self.__class__.__name__, self.weight_block_size, False) return TritonOrDeepGemmExperts( quant_config=self.moe_quant_config, allow_deep_gemm=self.allow_deep_gemm, @@ -988,7 +889,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): if self.block_quant else layer.w2_weight_scale), a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, - block_shape=self.quant_config.weight_block_size, + block_shape=self.weight_block_size, ) def apply( @@ -1046,7 +947,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): intermediate_size=layer.intermediate_size_per_partition, expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, - block_shape=self.quant_config.weight_block_size, + block_shape=self.weight_block_size, routed_scaling=routed_scaling_factor, ) else: diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index bbe0c6f6d38e..fc12483de0c0 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -17,6 +17,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( group_broadcast) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED) +from vllm.model_executor.parameter import (BlockQuantScaleParameter, + ChannelQuantScaleParameter, + PerTensorScaleParameter) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv, direct_register_custom_op @@ -794,3 +797,220 @@ def requant_weight_ue8m0_inplace( # Write back the results in-place. w_q.copy_(w_requant) s_old.copy_(s_requant) + + +def check_aiter_fp8_linear_support() -> bool: + """AITER is only supported on ROCm and only for FP8_FNUZ + and at the moment are MI300 series""" + return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_AITER_LINEAR + and current_platform.is_fp8_fnuz()) + + +def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor: + """Pad the weight tensor. This is an optimization on ROCm platform, which + can benefit from tensors located far enough from one another in memory""" + if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm() + and weight.stride(-1) == 1 + and (weight.stride(-2) * weight.element_size()) % 512 == 0): + num_pad = 256 // weight.element_size() + import torch.nn.functional as F + weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] + torch.cuda.empty_cache() + return weight + + +def validate_fp8_block_shape(layer: torch.nn.Module, input_size: int, + output_size: int, input_size_per_partition: int, + output_partition_sizes: list[int], + block_size: list[int]) -> None: + """Validate block quantization shapes for tensor parallelism.""" + from vllm.distributed import get_tensor_model_parallel_world_size + + tp_size = getattr(layer, "tp_size", get_tensor_model_parallel_world_size()) + block_n, block_k = block_size[0], block_size[1] + + # Required by row parallel + if (tp_size > 1 and input_size // input_size_per_partition == tp_size + and input_size_per_partition % block_k != 0): + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition} " + f"is not divisible by weight quantization block_k = {block_k}.") + + # Required by column parallel or enabling merged weights + is_tp_split = (tp_size > 1 + and output_size // sum(output_partition_sizes) == tp_size) + is_merged_gemm = len(output_partition_sizes) > 1 + if is_tp_split or is_merged_gemm: + sizes_to_check = output_partition_sizes + if not is_tp_split and is_merged_gemm: + # In case of merged matrices, we allow the last + # matrix to not be a multiple of block size + sizes_to_check = output_partition_sizes[:-1] + for output_partition_size in sizes_to_check: + if output_partition_size % block_n != 0: + raise ValueError( + f"Weight output_partition_size = " + f"{output_partition_size} is not divisible by " + f"weight quantization block_n = {block_n}.") + + +def create_fp8_weight_parameter( + output_size_per_partition: int, input_size_per_partition: int, + weight_loader: Optional[Callable]) -> torch.nn.Parameter: + """Create FP8 weight parameter.""" + from vllm.model_executor.parameter import ModelWeightParameter + + return ModelWeightParameter(data=torch.empty(output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + +def create_fp8_scale_parameter( + parameter_type: torch.nn.Parameter, output_partition_sizes: list[int], + input_size_per_partition: int, block_size: Optional[list[int]], + weight_loader: Optional[Callable]) -> torch.nn.Parameter: + """Create scale parameter based on quantization strategy.""" + if parameter_type == ChannelQuantScaleParameter: + scale = parameter_type(data=torch.empty( + (sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + elif parameter_type == BlockQuantScaleParameter: + assert block_size is not None + block_n, block_k = block_size[0], block_size[1] + output_size_per_partition = sum(output_partition_sizes) + scale = parameter_type( + data=torch.empty( + (output_size_per_partition + block_n - 1) // block_n, + (input_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + elif parameter_type == PerTensorScaleParameter: + scale = parameter_type(data=torch.empty(len(output_partition_sizes), + dtype=torch.float32), + weight_loader=weight_loader) + else: + raise ValueError(f"Unknown parameter type: {parameter_type}") + + scale[:] = torch.finfo(torch.float32).min + return scale + + +def create_fp8_input_scale( + output_partition_sizes: list[int], + weight_loader: Optional[Callable]) -> torch.nn.Parameter: + """Create input scale parameter for static activation quantization.""" + from vllm.model_executor.parameter import PerTensorScaleParameter + + scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + scale[:] = torch.finfo(torch.float32).min + return scale + + +def process_fp8_weight_tensor_strategy( + weight: torch.Tensor, + weight_scale: torch.Tensor, + logical_widths: list[int], + input_scale: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Process weights for tensor-wise quantization strategy.""" + from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale) + + if current_platform.is_fp8_fnuz(): + weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=weight_scale, input_scale=input_scale) + + # Requantize with max scale + weight_scale, weight = requantize_with_max_scale( + weight=weight, + weight_scale=weight_scale, + logical_widths=logical_widths, + ) + + weight = _maybe_pad_fp8_weight(weight) + return weight, weight_scale, input_scale + + +def process_fp8_weight_channel_strategy( + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Process weights for channel-wise quantization strategy.""" + from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + normalize_e4m3fn_to_e4m3fnuz) + + if current_platform.is_fp8_fnuz(): + weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=weight_scale, input_scale=input_scale) + + return weight, weight_scale, input_scale + + +def process_fp8_weight_block_strategy( + weight: torch.Tensor, + weight_scale: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Process weights for block-wise quantization strategy.""" + from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + normalize_e4m3fn_to_e4m3fnuz) + + if current_platform.is_fp8_fnuz(): + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=weight_scale) + + weight = _maybe_pad_fp8_weight(weight) + return weight, weight_scale + + +def maybe_post_process_fp8_weight_block(layer: torch.nn.Module, + cutlass_block_fp8_supported: bool): + assert layer.weight_block_size is not None + + from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used, + should_use_deepgemm_for_fp8_linear) + + # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to + # requantize the weight and input to the specific scale + # at the same time. + if is_deep_gemm_e8m0_used(): + block_sz = tuple(layer.weight_block_size) + requant_weight_ue8m0_inplace(layer.weight.data, + layer.weight_scale.data, block_sz) + # SM90 Block FP8 CUTLASS requires row-major weight scales + elif (current_platform.is_device_capability(90) + and cutlass_block_fp8_supported + and not should_use_deepgemm_for_fp8_linear(torch.bfloat16, + layer.weight)): + layer.weight_scale = torch.nn.Parameter( + layer.weight_scale.data.T.contiguous(), requires_grad=False) + + +def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor, + bias: Optional[torch.Tensor], + cutlass_block_fp8_supported: bool, + use_aiter_and_is_supported: bool) -> torch.Tensor: + """Apply block-wise FP8 linear operation.""" + assert layer.weight_block_size is not None + + return torch.ops.vllm.apply_w8a8_block_fp8_linear( + input=input, + weight=layer.weight, + block_size=layer.weight_block_size, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + cutlass_block_fp8_supported=cutlass_block_fp8_supported, + use_aiter_and_is_supported=use_aiter_and_is_supported, + )