Refactor dense FP8 tensor/channel/block utils and add CT FP8 block (#21404)

This commit is contained in:
Michael Goin 2025-09-18 08:53:45 -04:00 committed by GitHub
parent 470484a4f5
commit fbd6523ac0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 441 additions and 317 deletions

View File

@ -805,12 +805,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert loaded_shard_id < len(self.output_sizes) assert loaded_shard_id < len(self.output_sizes)
if isinstance(param, BlockQuantScaleParameter): if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
assert self.quant_method is not None assert self.quant_method is not None
assert isinstance(self.quant_method, # Assume the weight block size has been set by quant method
(Fp8LinearMethod, Fp8MoEMethod)) assert hasattr(self, "weight_block_size")
weight_block_size = self.quant_method.quant_config.weight_block_size weight_block_size = self.weight_block_size
assert weight_block_size is not None assert weight_block_size is not None
block_n, _ = weight_block_size[0], weight_block_size[1] block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = ( shard_offset = (
@ -989,8 +987,10 @@ class QKVParallelLinear(ColumnParallelLinear):
# Note(simon): This is needed for Qwen3's fp8 quantization. # Note(simon): This is needed for Qwen3's fp8 quantization.
if isinstance(param, BlockQuantScaleParameter): if isinstance(param, BlockQuantScaleParameter):
assert self.quant_method is not None assert self.quant_method is not None
assert hasattr(self.quant_method, "quant_config") # Assume the weight block size has been set by quant method
weight_block_size = self.quant_method.quant_config.weight_block_size assert hasattr(self, "weight_block_size")
weight_block_size = self.weight_block_size
assert weight_block_size is not None
block_n, _ = weight_block_size[0], weight_block_size[1] block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (shard_offset + block_n - 1) // block_n shard_offset = (shard_offset + block_n - 1) // block_n
shard_size = (shard_size + block_n - 1) // block_n shard_size = (shard_size + block_n - 1) // block_n

View File

@ -12,7 +12,6 @@ from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy, QuantizationStrategy,
QuantizationType) QuantizationType)
from compressed_tensors.transform import TransformConfig from compressed_tensors.transform import TransformConfig
from pydantic import BaseModel
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
@ -268,7 +267,8 @@ class CompressedTensorsConfig(QuantizationConfig):
else: else:
return False return False
def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel): def _is_fp4a4_nvfp4(self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs):
if weight_quant is None or input_quant is None: if weight_quant is None or input_quant is None:
return False return False
@ -288,8 +288,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return (is_tensor_group_quant and is_float_type and is_4_bits return (is_tensor_group_quant and is_float_type and is_4_bits
and is_group_size_16 and is_symmetric) and is_group_size_16 and is_symmetric)
def _is_fp4a16_nvfp4(self, weight_quant: BaseModel, def _is_fp4a16_nvfp4(self, weight_quant: QuantizationArgs,
input_quant: BaseModel): input_quant: QuantizationArgs):
is_weight_only = weight_quant is not None and input_quant is None is_weight_only = weight_quant is not None and input_quant is None
is_tensor_group_quant = ( is_tensor_group_quant = (
@ -303,8 +303,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return (is_weight_only and is_tensor_group_quant and is_float_type return (is_weight_only and is_tensor_group_quant and is_float_type
and is_4_bits and is_group_size_16 and is_symmetric) and is_4_bits and is_group_size_16 and is_symmetric)
def _is_static_tensor_w8a8(self, weight_quant: BaseModel, def _is_static_tensor_w8a8(self, weight_quant: QuantizationArgs,
input_quant: BaseModel) -> bool: input_quant: QuantizationArgs) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = ( weight_strategy = (
weight_quant.strategy == QuantizationStrategy.TENSOR.value weight_quant.strategy == QuantizationStrategy.TENSOR.value
@ -317,8 +317,8 @@ class CompressedTensorsConfig(QuantizationConfig):
# Only symmetric weight quantization supported. # Only symmetric weight quantization supported.
return is_8_bits and is_tensor and weight_quant.symmetric and is_static return is_8_bits and is_tensor and weight_quant.symmetric and is_static
def _is_dynamic_token_w8a8(self, weight_quant: BaseModel, def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs,
input_quant: BaseModel) -> bool: input_quant: QuantizationArgs) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = ( weight_strategy = (
weight_quant.strategy == QuantizationStrategy.TENSOR.value weight_quant.strategy == QuantizationStrategy.TENSOR.value
@ -331,8 +331,8 @@ class CompressedTensorsConfig(QuantizationConfig):
# Only symmetric weight quantization supported. # Only symmetric weight quantization supported.
return is_8_bits and is_token and weight_quant.symmetric and is_dynamic return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
def _is_dynamic_token_w4a8_int(self, weight_quant: BaseModel, def _is_dynamic_token_w4a8_int(self, weight_quant: QuantizationArgs,
input_quant: BaseModel) -> bool: input_quant: QuantizationArgs) -> bool:
is_weight_4_bits = weight_quant.num_bits == 4 is_weight_4_bits = weight_quant.num_bits == 4
is_activation_8_bits = input_quant.num_bits == 8 is_activation_8_bits = input_quant.num_bits == 8
weight_strategy = ( weight_strategy = (
@ -347,8 +347,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return (is_weight_4_bits and is_activation_8_bits and is_token return (is_weight_4_bits and is_activation_8_bits and is_token
and weight_quant.symmetric and is_dynamic) and weight_quant.symmetric and is_dynamic)
def _is_fp8_w8a8(self, weight_quant: BaseModel, def _is_fp8_w8a8(self, weight_quant: QuantizationArgs,
input_quant: BaseModel) -> bool: input_quant: QuantizationArgs) -> bool:
# Confirm weights and activations quantized. # Confirm weights and activations quantized.
if weight_quant is None or input_quant is None: if weight_quant is None or input_quant is None:
return False return False
@ -358,11 +358,12 @@ class CompressedTensorsConfig(QuantizationConfig):
and input_quant.type == QuantizationType.FLOAT) and input_quant.type == QuantizationType.FLOAT)
is_symmetric_weight = weight_quant.symmetric is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic is_static_weight = not weight_quant.dynamic
is_per_tensor_or_channel_weight = (weight_quant.strategy in [ is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL,
QuantizationStrategy.BLOCK
]) ])
if not (is_floating_point and is_symmetric_weight and is_static_weight if not (is_floating_point and is_symmetric_weight and is_static_weight
and is_per_tensor_or_channel_weight): and is_tensor_or_channel_or_block_weight):
return False return False
# Dynamic quantization is always supported if weights supported. # Dynamic quantization is always supported if weights supported.
@ -375,8 +376,8 @@ class CompressedTensorsConfig(QuantizationConfig):
input_quant.strategy == QuantizationStrategy.TENSOR) input_quant.strategy == QuantizationStrategy.TENSOR)
return is_symmetric_activation and is_per_tensor_activation return is_symmetric_activation and is_per_tensor_activation
def _is_fp8_w4a8(self, weight_quant: BaseModel, def _is_fp8_w4a8(self, weight_quant: QuantizationArgs,
input_quant: BaseModel) -> bool: input_quant: QuantizationArgs) -> bool:
if not weight_quant or not input_quant: if not weight_quant or not input_quant:
return False return False
is_weight_4_bits = weight_quant.num_bits == 4 is_weight_4_bits = weight_quant.num_bits == 4
@ -392,24 +393,24 @@ class CompressedTensorsConfig(QuantizationConfig):
return (is_weight_4_bits and is_activation_8_bits and is_token return (is_weight_4_bits and is_activation_8_bits and is_token
and is_symmetric and is_dynamic) and is_symmetric and is_dynamic)
def _is_fp8_w4a8_sm90(self, weight_quant: BaseModel, def _is_fp8_w4a8_sm90(self, weight_quant: QuantizationArgs,
input_quant: BaseModel) -> bool: input_quant: QuantizationArgs) -> bool:
return (self._check_scheme_supported(90, error=False, match_exact=True) return (self._check_scheme_supported(90, error=False, match_exact=True)
and self._is_fp8_w4a8(weight_quant, input_quant)) and self._is_fp8_w4a8(weight_quant, input_quant))
def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel, def _is_fp8_w8a8_sm90(self, weight_quant: QuantizationArgs,
input_quant: BaseModel) -> bool: input_quant: QuantizationArgs) -> bool:
return (self._check_scheme_supported(90, error=False, match_exact=True) return (self._check_scheme_supported(90, error=False, match_exact=True)
and self._is_fp8_w8a8(weight_quant, input_quant)) and self._is_fp8_w8a8(weight_quant, input_quant))
def _is_fp8_w8a8_sm100(self, weight_quant: BaseModel, def _is_fp8_w8a8_sm100(self, weight_quant: QuantizationArgs,
input_quant: BaseModel) -> bool: input_quant: QuantizationArgs) -> bool:
return (self._check_scheme_supported( return (self._check_scheme_supported(
100, error=False, match_exact=True) 100, error=False, match_exact=True)
and self._is_fp8_w8a8(weight_quant, input_quant)) and self._is_fp8_w8a8(weight_quant, input_quant))
def _is_fp8_w8a16(self, weight_quant: BaseModel, def _is_fp8_w8a16(self, weight_quant: QuantizationArgs,
input_quant: BaseModel) -> bool: input_quant: QuantizationArgs) -> bool:
# Confirm weights quantized. # Confirm weights quantized.
if weight_quant is None: if weight_quant is None:
return False return False
@ -421,18 +422,19 @@ class CompressedTensorsConfig(QuantizationConfig):
# Confirm weight scheme is supported. # Confirm weight scheme is supported.
is_symmetric_weight = weight_quant.symmetric is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic is_static_weight = not weight_quant.dynamic
is_per_tensor_or_channel_weight = (weight_quant.strategy in [ is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL,
QuantizationStrategy.BLOCK
]) ])
if not (is_symmetric_weight and is_static_weight # noqa: SIM103 if not (is_symmetric_weight and is_static_weight # noqa: SIM103
and is_per_tensor_or_channel_weight): and is_tensor_or_channel_or_block_weight):
return False return False
# All conditions satisfied. # All conditions satisfied.
return True return True
def _is_wNa16_group_channel(self, weight_quant: BaseModel, def _is_wNa16_group_channel(self, weight_quant: QuantizationArgs,
input_quant: BaseModel) -> bool: input_quant: QuantizationArgs) -> bool:
input_quant_none = input_quant is None input_quant_none = input_quant is None
is_channel_group = ( is_channel_group = (
weight_quant.strategy == QuantizationStrategy.CHANNEL.value weight_quant.strategy == QuantizationStrategy.CHANNEL.value
@ -443,8 +445,8 @@ class CompressedTensorsConfig(QuantizationConfig):
def _get_scheme_from_parts( def _get_scheme_from_parts(
self, self,
weight_quant: BaseModel, weight_quant: QuantizationArgs,
input_quant: BaseModel, input_quant: QuantizationArgs,
format: Optional[str] = None) -> "CompressedTensorsScheme": format: Optional[str] = None) -> "CompressedTensorsScheme":
# use the per-layer format if defined, otherwise, use global format # use the per-layer format if defined, otherwise, use global format
@ -496,7 +498,7 @@ class CompressedTensorsConfig(QuantizationConfig):
CompressedTensorsW8A8Fp8.get_min_capability(), error=False) CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
if is_fp8_w8a8_supported: if is_fp8_w8a8_supported:
return CompressedTensorsW8A8Fp8( return CompressedTensorsW8A8Fp8(
strategy=weight_quant.strategy, weight_quant=weight_quant,
is_static_input_scheme=(input_quant is_static_input_scheme=(input_quant
and not input_quant.dynamic)) and not input_quant.dynamic))
else: else:

View File

@ -4,28 +4,41 @@
from typing import Callable, Optional from typing import Callable, Optional
import torch import torch
from compressed_tensors.quantization import QuantizationStrategy from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy)
from torch.nn import Parameter from torch.nn import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme) CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_fp8_block_linear, check_aiter_fp8_linear_support,
create_fp8_input_scale, create_fp8_scale_parameter,
create_fp8_weight_parameter, maybe_post_process_fp8_weight_block,
process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy,
process_fp8_weight_tensor_strategy, validate_fp8_block_shape)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape) GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, Fp8LinearOp, cutlass_block_fp8_supported, maybe_create_device_identity)
requantize_with_max_scale) from vllm.model_executor.parameter import (BlockQuantScaleParameter,
from vllm.model_executor.parameter import (ChannelQuantScaleParameter, ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter) PerTensorScaleParameter)
from vllm.platforms import current_platform
__all__ = ["CompressedTensorsW8A8Fp8"] __all__ = ["CompressedTensorsW8A8Fp8"]
strategy_to_parameter_type = {
QuantizationStrategy.BLOCK: BlockQuantScaleParameter,
QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter,
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
}
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def __init__(self, strategy: str, is_static_input_scheme: bool): def __init__(self, weight_quant: QuantizationArgs,
self.strategy = strategy is_static_input_scheme: bool):
self.weight_quant = weight_quant
self.strategy = weight_quant.strategy
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
self.is_static_input_scheme = is_static_input_scheme self.is_static_input_scheme = is_static_input_scheme
self.act_q_group_shape = GroupShape.PER_TENSOR \ self.act_q_group_shape = GroupShape.PER_TENSOR \
@ -34,61 +47,84 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
act_quant_static=self.is_static_input_scheme, act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_q_group_shape) act_quant_group_shape=self.act_q_group_shape)
self.weight_block_size = self.weight_quant.block_structure
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
# lovelace and up # lovelace and up
return 89 return 89
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
weight_loader: Callable, **kwargs):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
layer.weight_block_size = None
if self.strategy == QuantizationStrategy.BLOCK:
assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size
# Validate block quantization shapes
validate_fp8_block_shape(layer, input_size, output_size,
input_size_per_partition,
output_partition_sizes,
self.weight_block_size)
# WEIGHT
weight = create_fp8_weight_parameter(output_size_per_partition,
input_size_per_partition,
weight_loader)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
weight_scale = create_fp8_scale_parameter(
strategy_to_parameter_type[self.strategy], output_partition_sizes,
input_size_per_partition, layer.weight_block_size, weight_loader)
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
if self.is_static_input_scheme:
input_scale = create_fp8_input_scale(output_partition_sizes,
weight_loader)
layer.register_parameter("input_scale", input_scale)
def process_weights_after_loading(self, layer) -> None: def process_weights_after_loading(self, layer) -> None:
# If per tensor, when we have a fused module (e.g. QKV) with per
# tensor scales (thus N scales being passed to the kernel),
# requantize so we can always run per tensor
if self.strategy == QuantizationStrategy.TENSOR: if self.strategy == QuantizationStrategy.TENSOR:
max_w_scale, weight = requantize_with_max_scale( weight, weight_scale, input_scale = (
weight=layer.weight, process_fp8_weight_tensor_strategy(
weight_scale=layer.weight_scale, layer.weight, layer.weight_scale, layer.logical_widths,
logical_widths=layer.logical_widths, getattr(layer, 'input_scale', None)))
) weight = weight.t()
if current_platform.is_fp8_fnuz():
input_scale = getattr(layer, 'input_scale', None)
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=max_w_scale,
input_scale=input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
# If channelwise, scales are already lined up, so just transpose.
elif self.strategy == QuantizationStrategy.CHANNEL: elif self.strategy == QuantizationStrategy.CHANNEL:
weight = layer.weight weight, weight_scale, input_scale = (
process_fp8_weight_channel_strategy(
layer.weight, layer.weight_scale,
getattr(layer, 'input_scale', None)))
weight = weight.t()
if current_platform.is_fp8_fnuz(): elif self.strategy == QuantizationStrategy.BLOCK:
input_scale = getattr(layer, 'input_scale', None) assert self.is_static_input_scheme is False
weight, weight_scale = process_fp8_weight_block_strategy(
weight, weight_scale, input_scale = \ layer.weight, layer.weight_scale)
normalize_e4m3fn_to_e4m3fnuz( input_scale = None
weight=weight,
weight_scale=layer.weight_scale,
input_scale=input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
else:
weight_scale = layer.weight_scale.data
layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
else: else:
raise ValueError(f"Unknown quantization strategy {self.strategy}") raise ValueError(f"Unknown quantization strategy {self.strategy}")
# required by torch.compile to be torch.nn.Parameter
layer.weight = Parameter(weight.data, requires_grad=False)
layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
if input_scale is not None:
layer.input_scale = Parameter(input_scale.data,
requires_grad=False)
# INPUT SCALE # INPUT SCALE
if self.is_static_input_scheme and hasattr(layer, 'input_scale'): if self.is_static_input_scheme and hasattr(layer, 'input_scale'):
layer.input_scale = Parameter(layer.input_scale.max(), layer.input_scale = Parameter(layer.input_scale.max(),
@ -96,58 +132,23 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
else: else:
layer.input_scale = None layer.input_scale = None
def create_weights(self, layer: torch.nn.Module, if self.strategy == QuantizationStrategy.BLOCK:
output_partition_sizes: list[int], maybe_post_process_fp8_weight_block(
input_size_per_partition: int, layer, self.cutlass_block_fp8_supported)
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
# WEIGHT
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
# TODO: update create_xxx_parameter functions to return
# the newly added parameters
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
else:
assert self.strategy == QuantizationStrategy.TENSOR
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
# min requirement for fp8 kernels
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
if self.is_static_input_scheme:
input_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
input_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", input_scale)
def apply_weights(self, def apply_weights(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if layer.weight_block_size is not None:
return apply_fp8_block_linear(
layer,
input=x,
bias=bias,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported)
return self.fp8_linear.apply(input=x, return self.fp8_linear.apply(input=x,
weight=layer.weight, weight=layer.weight,
weight_scale=layer.weight_scale, weight_scale=layer.weight_scale,

View File

@ -4,7 +4,6 @@
from typing import TYPE_CHECKING, Any, Callable, Optional, Union from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch import torch
import torch.nn.functional as F
from torch.nn import Module from torch.nn import Module
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
@ -32,8 +31,12 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
select_cutlass_fp8_gemm_impl, swap_w13_to_w31) select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace, apply_fp8_block_linear, check_aiter_fp8_linear_support,
should_use_deepgemm_for_fp8_linear) create_fp8_input_scale, create_fp8_scale_parameter,
create_fp8_weight_parameter, get_col_major_tma_aligned_tensor,
maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace,
validate_fp8_block_shape)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
prepare_moe_fp8_layer_for_marlin) prepare_moe_fp8_layer_for_marlin)
@ -42,8 +45,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported, Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
cutlass_fp8_supported, maybe_create_device_identity, cutlass_fp8_supported, maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
requantize_with_max_scale)
from vllm.model_executor.parameter import (BlockQuantScaleParameter, from vllm.model_executor.parameter import (BlockQuantScaleParameter,
ModelWeightParameter, ModelWeightParameter,
PerTensorScaleParameter) PerTensorScaleParameter)
@ -233,14 +235,10 @@ class Fp8LinearMethod(LinearMethodBase):
if current_platform.is_rocm(): if current_platform.is_rocm():
self.use_marlin = False self.use_marlin = False
# AITER is only supported on ROCm and only for FP8_FNUZ self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
# and at the moment are MI300 series
self.use_aiter_and_is_supported = (current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_LINEAR
and current_platform.is_fp8_fnuz())
self.block_quant = self.quant_config.weight_block_size is not None self.weight_block_size = self.quant_config.weight_block_size
self.block_quant = self.weight_block_size is not None
self.act_q_static = self.quant_config.activation_scheme == "static" self.act_q_static = self.quant_config.activation_scheme == "static"
# Use per-token quantization for better perf if dynamic and cutlass # Use per-token quantization for better perf if dynamic and cutlass
if not self.act_q_static and cutlass_fp8_supported(): if not self.act_q_static and cutlass_fp8_supported():
@ -273,51 +271,27 @@ class Fp8LinearMethod(LinearMethodBase):
layer.weight_block_size = None layer.weight_block_size = None
if self.block_quant: if self.block_quant:
tp_size = getattr(layer, "tp_size", assert self.weight_block_size is not None
get_tensor_model_parallel_world_size()) layer.weight_block_size = self.weight_block_size
assert self.quant_config.weight_block_size is not None validate_fp8_block_shape(layer, input_size, output_size,
layer.weight_block_size = self.quant_config.weight_block_size input_size_per_partition,
block_n, block_k = ( output_partition_sizes,
self.quant_config.weight_block_size[0], self.weight_block_size)
self.quant_config.weight_block_size[1],
)
# Required by row parallel
if (tp_size > 1
and input_size // input_size_per_partition == tp_size
and input_size_per_partition % block_k != 0):
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}.")
# Required by column parallel or enabling merged weights
is_tp_split = (tp_size > 1 and
output_size // output_size_per_partition == tp_size)
is_merged_gemm = len(output_partition_sizes) > 1
if is_tp_split or is_merged_gemm:
sizes_to_check = output_partition_sizes
if not is_tp_split and is_merged_gemm:
# In case of merged matrices, we allow the last
# matrix to not be a multiple of block size
sizes_to_check = output_partition_sizes[:-1]
for output_partition_size in sizes_to_check:
if output_partition_size % block_n != 0:
raise ValueError(
f"Weight output_partition_size = "
f"{output_partition_size} is not divisible by "
f"weight quantization block_n = {block_n}.")
# WEIGHT # WEIGHT
weight_dtype = (torch.float8_e4m3fn if self.quant_config.is_checkpoint_fp8_serialized:
if self.quant_config.is_checkpoint_fp8_serialized else weight = create_fp8_weight_parameter(output_size_per_partition,
params_dtype) input_size_per_partition,
weight_loader)
weight = ModelWeightParameter(data=torch.empty( else:
output_size_per_partition, # For non-serialized checkpoints, use original dtype
input_size_per_partition, weight = ModelWeightParameter(data=torch.empty(
dtype=weight_dtype), output_size_per_partition,
input_dim=1, input_size_per_partition,
output_dim=0, dtype=params_dtype),
weight_loader=weight_loader) input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight) layer.register_parameter("weight", weight)
# If checkpoint is serialized fp8, load them. # If checkpoint is serialized fp8, load them.
@ -325,154 +299,87 @@ class Fp8LinearMethod(LinearMethodBase):
if self.quant_config.is_checkpoint_fp8_serialized: if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE # WEIGHT SCALE
if not self.block_quant: if not self.block_quant:
scale = PerTensorScaleParameter( scale = create_fp8_scale_parameter(PerTensorScaleParameter,
data=torch.empty(len(output_partition_sizes), output_partition_sizes,
dtype=torch.float32), input_size_per_partition,
weight_loader=weight_loader, None, weight_loader)
)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "weight_scale"}) set_weight_attrs(scale, {"scale_type": "weight_scale"})
layer.register_parameter("weight_scale", scale) layer.register_parameter("weight_scale", scale)
else: else:
assert self.quant_config.activation_scheme == "dynamic" assert not self.act_q_static
scale = BlockQuantScaleParameter( assert self.weight_block_size is not None
data=torch.empty( scale = create_fp8_scale_parameter(BlockQuantScaleParameter,
(output_size_per_partition + block_n - 1) // block_n, output_partition_sizes,
(input_size_per_partition + block_k - 1) // block_k, input_size_per_partition,
dtype=torch.float32, self.weight_block_size,
), weight_loader)
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "weight_scale"}) set_weight_attrs(scale, {"scale_type": "weight_scale"})
# The weight_scale_inv name is intentional for deepseekv3 # The weight_scale_inv name is intentional for deepseekv3
layer.register_parameter("weight_scale_inv", scale) layer.register_parameter("weight_scale_inv", scale)
# INPUT ACTIVATION SCALE # INPUT ACTIVATION SCALE
if self.quant_config.activation_scheme == "static": if self.act_q_static:
scale = PerTensorScaleParameter(data=torch.empty( scale = create_fp8_input_scale(output_partition_sizes,
len(output_partition_sizes), dtype=torch.float32), weight_loader)
weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "input_scale"}) set_weight_attrs(scale, {"scale_type": "input_scale"})
layer.register_parameter("input_scale", scale) layer.register_parameter("input_scale", scale)
else: else:
layer.register_parameter("input_scale", None) layer.register_parameter("input_scale", None)
def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
# Pad the weight tensor. This is an optimization on ROCm platform, which
# can benefit from tensors located far enough from one another in memory
if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
and weight.stride(-1) == 1
and (weight.stride(-2) * weight.element_size()) % 512 == 0):
num_pad = 256 // weight.element_size()
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache()
return weight
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
size_k_first = True size_k_first = True
input_scale = None
# TODO(rob): refactor block quant into separate class. # TODO(rob): refactor block quant into separate class.
if self.block_quant: if self.block_quant:
assert self.quant_config.activation_scheme == "dynamic" assert not self.act_q_static
size_k_first = False size_k_first = False
if current_platform.is_fp8_fnuz():
weight, weight_scale_inv, _ = \
normalize_e4m3fn_to_e4m3fnuz(
weight=layer.weight,
weight_scale=layer.weight_scale_inv)
else:
weight = layer.weight.data
weight_scale_inv = layer.weight_scale_inv.data
weight = self._maybe_pad_weight(weight) weight, weight_scale = process_fp8_weight_block_strategy(
layer.weight, layer.weight_scale_inv)
# Torch.compile cannot use Parameter subclasses. # Delete the weight_scale_inv parameter to avoid confusion
layer.weight = Parameter(weight, requires_grad=False) # with the weight_scale parameter
layer.weight_scale_inv = Parameter(weight_scale_inv, del layer.weight_scale_inv
requires_grad=False)
# If checkpoint not serialized fp8, quantize the weights. # If checkpoint not serialized fp8, quantize the weights.
elif not self.quant_config.is_checkpoint_fp8_serialized: elif not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
scale=None) scale=None)
weight = qweight.t()
# Update the layer with the new values. # If checkpoint is fp8 per-tensor, handle that there are N scales for N
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
# layer.input_scale is None indicates dynamic quant and scale is
# computed from input.
layer.input_scale = None
# If checkpoint is fp8, handle that there are N scales for N
# shards in a fused module # shards in a fused module
else: else:
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
requires_grad=False)
if self.quant_config.activation_scheme == "static":
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False)
weight = layer.weight weight = layer.weight
weight_scale = layer.weight_scale weight_scale = layer.weight_scale
# If using w8a8, torch._scaled_mm needs per tensor, so # If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight. # requantize the logical shards as a single weight.
if not self.use_marlin: if not self.use_marlin:
# Dequant -> Quant with max scale so we can run per tensor. weight, weight_scale, input_scale = (
if current_platform.is_fp8_fnuz(): process_fp8_weight_tensor_strategy(
weight, weight_scale, input_scale = \ weight, weight_scale, layer.logical_widths,
normalize_e4m3fn_to_e4m3fnuz( getattr(layer, 'input_scale', None)))
weight=weight, if self.act_q_static:
weight_scale=weight_scale, assert input_scale is not None
input_scale=layer.input_scale) input_scale = input_scale.max()
if input_scale is not None: weight = weight.t()
layer.input_scale = Parameter(input_scale,
requires_grad=False)
weight_scale, weight = requantize_with_max_scale( # Update layer with new values.
weight=weight, layer.weight = Parameter(weight.data, requires_grad=False)
weight_scale=weight_scale, layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
logical_widths=layer.logical_widths, layer.input_scale = Parameter(
) input_scale,
requires_grad=False) if input_scale is not None else None
weight = self._maybe_pad_weight(weight)
# Update layer with new values.
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
if self.quant_config.activation_scheme == "static":
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
if self.use_marlin: if self.use_marlin:
prepare_fp8_layer_for_marlin(layer, size_k_first) prepare_fp8_layer_for_marlin(layer, size_k_first)
# Activations not quantized for marlin. # Activations not quantized for marlin.
del layer.input_scale del layer.input_scale
return
# On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to if self.block_quant:
# requantize the weight and input to the specific scale maybe_post_process_fp8_weight_block(
# at the same time. layer, self.cutlass_block_fp8_supported)
if is_deep_gemm_e8m0_used() and self.block_quant:
assert layer.weight_block_size is not None
block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace(
layer.weight.data,
layer.weight_scale_inv.data if hasattr(
layer, "weight_scale_inv") else layer.weight_scale.data,
block_sz,
)
# SM90 Block FP8 CUTLASS requires row-major weight scales
if (self.block_quant and current_platform.is_device_capability(90)
and self.cutlass_block_fp8_supported
and not should_use_deepgemm_for_fp8_linear(
torch.bfloat16, layer.weight)):
layer.weight_scale_inv = Parameter(
layer.weight_scale_inv.data.T.contiguous(),
requires_grad=False)
def apply(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
@ -490,18 +397,12 @@ class Fp8LinearMethod(LinearMethodBase):
bias=bias) bias=bias)
if self.block_quant: if self.block_quant:
assert self.quant_config.weight_block_size is not None return apply_fp8_block_linear(
layer,
return torch.ops.vllm.apply_w8a8_block_fp8_linear(
input=x, input=x,
weight=layer.weight,
block_size=self.quant_config.weight_block_size,
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale,
bias=bias, bias=bias,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported, use_aiter_and_is_supported=self.use_aiter_and_is_supported)
)
return self.fp8_linear.apply(input=x, return self.fp8_linear.apply(input=x,
weight=layer.weight, weight=layer.weight,
@ -528,7 +429,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
super().__init__(layer.moe_config) super().__init__(layer.moe_config)
self.layer = layer self.layer = layer
self.quant_config = quant_config self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None self.weight_block_size = self.quant_config.weight_block_size
self.block_quant = self.weight_block_size is not None
self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
self.fused_experts: Optional[ self.fused_experts: Optional[
@ -590,12 +492,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if self.quant_config.is_checkpoint_fp8_serialized: if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn params_dtype = torch.float8_e4m3fn
if self.block_quant: if self.block_quant:
assert self.quant_config.weight_block_size is not None assert self.weight_block_size is not None
layer.weight_block_size = self.quant_config.weight_block_size layer.weight_block_size = self.weight_block_size
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
block_n, block_k = ( block_n, block_k = (
self.quant_config.weight_block_size[0], self.weight_block_size[0],
self.quant_config.weight_block_size[1], self.weight_block_size[1],
) )
# NOTE: To ensure proper alignment of the block-wise quantization # NOTE: To ensure proper alignment of the block-wise quantization
# scales, the output_size of the weights for both the gate and up # scales, the output_size of the weights for both the gate and up
@ -952,7 +854,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"BatchedTritonOrDeepGemmExperts(%s): " "BatchedTritonOrDeepGemmExperts(%s): "
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s", "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
self.__class__.__name__, max_num_tokens_per_rank, self.__class__.__name__, max_num_tokens_per_rank,
self.quant_config.weight_block_size, False) self.weight_block_size, False)
return BatchedTritonOrDeepGemmExperts( return BatchedTritonOrDeepGemmExperts(
max_num_tokens=max_num_tokens_per_rank, max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
@ -969,8 +871,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
else: else:
logger.debug( logger.debug(
"TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s", "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
self.__class__.__name__, self.quant_config.weight_block_size, self.__class__.__name__, self.weight_block_size, False)
False)
return TritonOrDeepGemmExperts( return TritonOrDeepGemmExperts(
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
allow_deep_gemm=self.allow_deep_gemm, allow_deep_gemm=self.allow_deep_gemm,
@ -988,7 +889,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if self.block_quant else layer.w2_weight_scale), if self.block_quant else layer.w2_weight_scale),
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size, block_shape=self.weight_block_size,
) )
def apply( def apply(
@ -1046,7 +947,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
intermediate_size=layer.intermediate_size_per_partition, intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts, expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts, local_num_experts=layer.local_num_experts,
block_shape=self.quant_config.weight_block_size, block_shape=self.weight_block_size,
routed_scaling=routed_scaling_factor, routed_scaling=routed_scaling_factor,
) )
else: else:

View File

@ -17,6 +17,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
group_broadcast) group_broadcast)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED) CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
ChannelQuantScaleParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, direct_register_custom_op from vllm.utils import cdiv, direct_register_custom_op
@ -794,3 +797,220 @@ def requant_weight_ue8m0_inplace(
# Write back the results in-place. # Write back the results in-place.
w_q.copy_(w_requant) w_q.copy_(w_requant)
s_old.copy_(s_requant) s_old.copy_(s_requant)
def check_aiter_fp8_linear_support() -> bool:
"""AITER is only supported on ROCm and only for FP8_FNUZ
and at the moment are MI300 series"""
return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_LINEAR
and current_platform.is_fp8_fnuz())
def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
"""Pad the weight tensor. This is an optimization on ROCm platform, which
can benefit from tensors located far enough from one another in memory"""
if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
and weight.stride(-1) == 1
and (weight.stride(-2) * weight.element_size()) % 512 == 0):
num_pad = 256 // weight.element_size()
import torch.nn.functional as F
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache()
return weight
def validate_fp8_block_shape(layer: torch.nn.Module, input_size: int,
output_size: int, input_size_per_partition: int,
output_partition_sizes: list[int],
block_size: list[int]) -> None:
"""Validate block quantization shapes for tensor parallelism."""
from vllm.distributed import get_tensor_model_parallel_world_size
tp_size = getattr(layer, "tp_size", get_tensor_model_parallel_world_size())
block_n, block_k = block_size[0], block_size[1]
# Required by row parallel
if (tp_size > 1 and input_size // input_size_per_partition == tp_size
and input_size_per_partition % block_k != 0):
raise ValueError(
f"Weight input_size_per_partition = {input_size_per_partition} "
f"is not divisible by weight quantization block_k = {block_k}.")
# Required by column parallel or enabling merged weights
is_tp_split = (tp_size > 1
and output_size // sum(output_partition_sizes) == tp_size)
is_merged_gemm = len(output_partition_sizes) > 1
if is_tp_split or is_merged_gemm:
sizes_to_check = output_partition_sizes
if not is_tp_split and is_merged_gemm:
# In case of merged matrices, we allow the last
# matrix to not be a multiple of block size
sizes_to_check = output_partition_sizes[:-1]
for output_partition_size in sizes_to_check:
if output_partition_size % block_n != 0:
raise ValueError(
f"Weight output_partition_size = "
f"{output_partition_size} is not divisible by "
f"weight quantization block_n = {block_n}.")
def create_fp8_weight_parameter(
output_size_per_partition: int, input_size_per_partition: int,
weight_loader: Optional[Callable]) -> torch.nn.Parameter:
"""Create FP8 weight parameter."""
from vllm.model_executor.parameter import ModelWeightParameter
return ModelWeightParameter(data=torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
def create_fp8_scale_parameter(
parameter_type: torch.nn.Parameter, output_partition_sizes: list[int],
input_size_per_partition: int, block_size: Optional[list[int]],
weight_loader: Optional[Callable]) -> torch.nn.Parameter:
"""Create scale parameter based on quantization strategy."""
if parameter_type == ChannelQuantScaleParameter:
scale = parameter_type(data=torch.empty(
(sum(output_partition_sizes), 1), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
elif parameter_type == BlockQuantScaleParameter:
assert block_size is not None
block_n, block_k = block_size[0], block_size[1]
output_size_per_partition = sum(output_partition_sizes)
scale = parameter_type(
data=torch.empty(
(output_size_per_partition + block_n - 1) // block_n,
(input_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
elif parameter_type == PerTensorScaleParameter:
scale = parameter_type(data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader)
else:
raise ValueError(f"Unknown parameter type: {parameter_type}")
scale[:] = torch.finfo(torch.float32).min
return scale
def create_fp8_input_scale(
output_partition_sizes: list[int],
weight_loader: Optional[Callable]) -> torch.nn.Parameter:
"""Create input scale parameter for static activation quantization."""
from vllm.model_executor.parameter import PerTensorScaleParameter
scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
return scale
def process_fp8_weight_tensor_strategy(
weight: torch.Tensor,
weight_scale: torch.Tensor,
logical_widths: list[int],
input_scale: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Process weights for tensor-wise quantization strategy."""
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
if current_platform.is_fp8_fnuz():
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=weight_scale, input_scale=input_scale)
# Requantize with max scale
weight_scale, weight = requantize_with_max_scale(
weight=weight,
weight_scale=weight_scale,
logical_widths=logical_widths,
)
weight = _maybe_pad_fp8_weight(weight)
return weight, weight_scale, input_scale
def process_fp8_weight_channel_strategy(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Process weights for channel-wise quantization strategy."""
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
normalize_e4m3fn_to_e4m3fnuz)
if current_platform.is_fp8_fnuz():
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=weight_scale, input_scale=input_scale)
return weight, weight_scale, input_scale
def process_fp8_weight_block_strategy(
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Process weights for block-wise quantization strategy."""
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
normalize_e4m3fn_to_e4m3fnuz)
if current_platform.is_fp8_fnuz():
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=weight_scale)
weight = _maybe_pad_fp8_weight(weight)
return weight, weight_scale
def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
cutlass_block_fp8_supported: bool):
assert layer.weight_block_size is not None
from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
should_use_deepgemm_for_fp8_linear)
# On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
# requantize the weight and input to the specific scale
# at the same time.
if is_deep_gemm_e8m0_used():
block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace(layer.weight.data,
layer.weight_scale.data, block_sz)
# SM90 Block FP8 CUTLASS requires row-major weight scales
elif (current_platform.is_device_capability(90)
and cutlass_block_fp8_supported
and not should_use_deepgemm_for_fp8_linear(torch.bfloat16,
layer.weight)):
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data.T.contiguous(), requires_grad=False)
def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor,
bias: Optional[torch.Tensor],
cutlass_block_fp8_supported: bool,
use_aiter_and_is_supported: bool) -> torch.Tensor:
"""Apply block-wise FP8 linear operation."""
assert layer.weight_block_size is not None
return torch.ops.vllm.apply_w8a8_block_fp8_linear(
input=input,
weight=layer.weight,
block_size=layer.weight_block_size,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_block_fp8_supported=cutlass_block_fp8_supported,
use_aiter_and_is_supported=use_aiter_and_is_supported,
)