Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 02:35:24 +08:00)
Refactor dense FP8 tensor/channel/block utils and add CT FP8 block (#21404)
This commit is contained in:
  parent: 470484a4f5
  commit: fbd6523ac0
@@ -805,12 +805,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
        assert loaded_shard_id < len(self.output_sizes)

        if isinstance(param, BlockQuantScaleParameter):
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8LinearMethod, Fp8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method,
                              (Fp8LinearMethod, Fp8MoEMethod))
            weight_block_size = self.quant_method.quant_config.weight_block_size
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            shard_offset = (
@@ -989,8 +987,10 @@ class QKVParallelLinear(ColumnParallelLinear):
        # Note(simon): This is needed for Qwen3's fp8 quantization.
        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            assert hasattr(self.quant_method, "quant_config")
            weight_block_size = self.quant_method.quant_config.weight_block_size
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n
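The two hunks above load BlockQuantScaleParameter shards by converting element offsets into block-scale offsets with ceiling division over block_n. A minimal, self-contained sketch of that index math (the names below are illustrative, not vLLM APIs):

    # Each block of `block_n` output rows shares one scale entry, so shard
    # offsets/sizes in elements map to scale rows via ceiling division.
    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    block_n = 128
    shard_offset_elems, shard_size_elems = 4096, 11008
    shard_offset_scales = ceil_div(shard_offset_elems, block_n)  # 32
    shard_size_scales = ceil_div(shard_size_elems, block_n)      # 86

This way a fused weight made of several logical shards indexes its per-block scale tensor consistently even when a shard size is not an exact multiple of block_n.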
@@ -12,7 +12,6 @@ from compressed_tensors.quantization import (QuantizationArgs,
                                              QuantizationStrategy,
                                              QuantizationType)
from compressed_tensors.transform import TransformConfig
from pydantic import BaseModel

import vllm.envs as envs
from vllm.logger import init_logger
@@ -268,7 +267,8 @@ class CompressedTensorsConfig(QuantizationConfig):
        else:
            return False

    def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):
    def _is_fp4a4_nvfp4(self, weight_quant: QuantizationArgs,
                        input_quant: QuantizationArgs):

        if weight_quant is None or input_quant is None:
            return False
@@ -288,8 +288,8 @@ class CompressedTensorsConfig(QuantizationConfig):
        return (is_tensor_group_quant and is_float_type and is_4_bits
                and is_group_size_16 and is_symmetric)

    def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
                         input_quant: BaseModel):
    def _is_fp4a16_nvfp4(self, weight_quant: QuantizationArgs,
                         input_quant: QuantizationArgs):

        is_weight_only = weight_quant is not None and input_quant is None
        is_tensor_group_quant = (
@@ -303,8 +303,8 @@ class CompressedTensorsConfig(QuantizationConfig):
        return (is_weight_only and is_tensor_group_quant and is_float_type
                and is_4_bits and is_group_size_16 and is_symmetric)

    def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                               input_quant: BaseModel) -> bool:
    def _is_static_tensor_w8a8(self, weight_quant: QuantizationArgs,
                               input_quant: QuantizationArgs) -> bool:
        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
        weight_strategy = (
            weight_quant.strategy == QuantizationStrategy.TENSOR.value
@@ -317,8 +317,8 @@ class CompressedTensorsConfig(QuantizationConfig):
        # Only symmetric weight quantization supported.
        return is_8_bits and is_tensor and weight_quant.symmetric and is_static

    def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
                               input_quant: BaseModel) -> bool:
    def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs,
                               input_quant: QuantizationArgs) -> bool:
        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
        weight_strategy = (
            weight_quant.strategy == QuantizationStrategy.TENSOR.value
@@ -331,8 +331,8 @@ class CompressedTensorsConfig(QuantizationConfig):
        # Only symmetric weight quantization supported.
        return is_8_bits and is_token and weight_quant.symmetric and is_dynamic

    def _is_dynamic_token_w4a8_int(self, weight_quant: BaseModel,
                                   input_quant: BaseModel) -> bool:
    def _is_dynamic_token_w4a8_int(self, weight_quant: QuantizationArgs,
                                   input_quant: QuantizationArgs) -> bool:
        is_weight_4_bits = weight_quant.num_bits == 4
        is_activation_8_bits = input_quant.num_bits == 8
        weight_strategy = (
@@ -347,8 +347,8 @@ class CompressedTensorsConfig(QuantizationConfig):
        return (is_weight_4_bits and is_activation_8_bits and is_token
                and weight_quant.symmetric and is_dynamic)

    def _is_fp8_w8a8(self, weight_quant: BaseModel,
                     input_quant: BaseModel) -> bool:
    def _is_fp8_w8a8(self, weight_quant: QuantizationArgs,
                     input_quant: QuantizationArgs) -> bool:
        # Confirm weights and activations quantized.
        if weight_quant is None or input_quant is None:
            return False
@@ -358,11 +358,12 @@ class CompressedTensorsConfig(QuantizationConfig):
                             and input_quant.type == QuantizationType.FLOAT)
        is_symmetric_weight = weight_quant.symmetric
        is_static_weight = not weight_quant.dynamic
        is_per_tensor_or_channel_weight = (weight_quant.strategy in [
            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
        is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [
            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL,
            QuantizationStrategy.BLOCK
        ])
        if not (is_floating_point and is_symmetric_weight and is_static_weight
                and is_per_tensor_or_channel_weight):
                and is_tensor_or_channel_or_block_weight):
            return False

        # Dynamic quantization is always supported if weights supported.
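The hunk above widens the weight-side condition of _is_fp8_w8a8 so that block-wise scales are accepted alongside per-tensor and per-channel ones. A hedged, self-contained sketch of that predicate using a plain dataclass as a stand-in for the real compressed-tensors QuantizationArgs (field names are illustrative only):

    from dataclasses import dataclass

    @dataclass
    class FakeQuantArgs:  # stand-in, not compressed_tensors.QuantizationArgs
        num_bits: int
        type: str        # "float" or "int"
        strategy: str    # "tensor", "channel", or "block"
        symmetric: bool
        dynamic: bool

    def fp8_w8a8_weight_ok(w: FakeQuantArgs) -> bool:
        # Mirrors the weight-side check: symmetric, static, 8-bit float,
        # and now tensor-, channel-, or block-wise scales.
        return (w.type == "float" and w.num_bits == 8 and w.symmetric
                and not w.dynamic
                and w.strategy in ("tensor", "channel", "block"))

    assert fp8_w8a8_weight_ok(
        FakeQuantArgs(8, "float", "block", symmetric=True, dynamic=False))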
@@ -375,8 +376,8 @@ class CompressedTensorsConfig(QuantizationConfig):
            input_quant.strategy == QuantizationStrategy.TENSOR)
        return is_symmetric_activation and is_per_tensor_activation

    def _is_fp8_w4a8(self, weight_quant: BaseModel,
                     input_quant: BaseModel) -> bool:
    def _is_fp8_w4a8(self, weight_quant: QuantizationArgs,
                     input_quant: QuantizationArgs) -> bool:
        if not weight_quant or not input_quant:
            return False
        is_weight_4_bits = weight_quant.num_bits == 4
@@ -392,24 +393,24 @@ class CompressedTensorsConfig(QuantizationConfig):
        return (is_weight_4_bits and is_activation_8_bits and is_token
                and is_symmetric and is_dynamic)

    def _is_fp8_w4a8_sm90(self, weight_quant: BaseModel,
                          input_quant: BaseModel) -> bool:
    def _is_fp8_w4a8_sm90(self, weight_quant: QuantizationArgs,
                          input_quant: QuantizationArgs) -> bool:
        return (self._check_scheme_supported(90, error=False, match_exact=True)
                and self._is_fp8_w4a8(weight_quant, input_quant))

    def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel,
                          input_quant: BaseModel) -> bool:
    def _is_fp8_w8a8_sm90(self, weight_quant: QuantizationArgs,
                          input_quant: QuantizationArgs) -> bool:
        return (self._check_scheme_supported(90, error=False, match_exact=True)
                and self._is_fp8_w8a8(weight_quant, input_quant))

    def _is_fp8_w8a8_sm100(self, weight_quant: BaseModel,
                           input_quant: BaseModel) -> bool:
    def _is_fp8_w8a8_sm100(self, weight_quant: QuantizationArgs,
                           input_quant: QuantizationArgs) -> bool:
        return (self._check_scheme_supported(
            100, error=False, match_exact=True)
                and self._is_fp8_w8a8(weight_quant, input_quant))

    def _is_fp8_w8a16(self, weight_quant: BaseModel,
                      input_quant: BaseModel) -> bool:
    def _is_fp8_w8a16(self, weight_quant: QuantizationArgs,
                      input_quant: QuantizationArgs) -> bool:
        # Confirm weights quantized.
        if weight_quant is None:
            return False
@@ -421,18 +422,19 @@ class CompressedTensorsConfig(QuantizationConfig):
        # Confirm weight scheme is supported.
        is_symmetric_weight = weight_quant.symmetric
        is_static_weight = not weight_quant.dynamic
        is_per_tensor_or_channel_weight = (weight_quant.strategy in [
            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
        is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [
            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL,
            QuantizationStrategy.BLOCK
        ])
        if not (is_symmetric_weight and is_static_weight  # noqa: SIM103
                and is_per_tensor_or_channel_weight):
                and is_tensor_or_channel_or_block_weight):
            return False

        # All conditions satisfied.
        return True

    def _is_wNa16_group_channel(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
    def _is_wNa16_group_channel(self, weight_quant: QuantizationArgs,
                                input_quant: QuantizationArgs) -> bool:
        input_quant_none = input_quant is None
        is_channel_group = (
            weight_quant.strategy == QuantizationStrategy.CHANNEL.value
@@ -443,8 +445,8 @@ class CompressedTensorsConfig(QuantizationConfig):

    def _get_scheme_from_parts(
            self,
            weight_quant: BaseModel,
            input_quant: BaseModel,
            weight_quant: QuantizationArgs,
            input_quant: QuantizationArgs,
            format: Optional[str] = None) -> "CompressedTensorsScheme":

        # use the per-layer format if defined, otherwise, use global format
@@ -496,7 +498,7 @@ class CompressedTensorsConfig(QuantizationConfig):
                CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
            if is_fp8_w8a8_supported:
                return CompressedTensorsW8A8Fp8(
                    strategy=weight_quant.strategy,
                    weight_quant=weight_quant,
                    is_static_input_scheme=(input_quant
                                            and not input_quant.dynamic))
        else:

@@ -4,28 +4,41 @@
from typing import Callable, Optional

import torch
from compressed_tensors.quantization import QuantizationStrategy
from compressed_tensors.quantization import (QuantizationArgs,
                                             QuantizationStrategy)
from torch.nn import Parameter

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    apply_fp8_block_linear, check_aiter_fp8_linear_support,
    create_fp8_input_scale, create_fp8_scale_parameter,
    create_fp8_weight_parameter, maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy,
    process_fp8_weight_tensor_strategy, validate_fp8_block_shape)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
    requantize_with_max_scale)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                           ModelWeightParameter,
    Fp8LinearOp, cutlass_block_fp8_supported, maybe_create_device_identity)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ChannelQuantScaleParameter,
                                           PerTensorScaleParameter)
from vllm.platforms import current_platform

__all__ = ["CompressedTensorsW8A8Fp8"]

strategy_to_parameter_type = {
    QuantizationStrategy.BLOCK: BlockQuantScaleParameter,
    QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter,
    QuantizationStrategy.TENSOR: PerTensorScaleParameter,
}

class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):

    def __init__(self, strategy: str, is_static_input_scheme: bool):
        self.strategy = strategy
    def __init__(self, weight_quant: QuantizationArgs,
                 is_static_input_scheme: bool):
        self.weight_quant = weight_quant
        self.strategy = weight_quant.strategy
        self.out_dtype = torch.get_default_dtype()
        self.is_static_input_scheme = is_static_input_scheme
        self.act_q_group_shape = GroupShape.PER_TENSOR \
@@ -34,61 +47,84 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
            act_quant_static=self.is_static_input_scheme,
            act_quant_group_shape=self.act_q_group_shape)

        self.weight_block_size = self.weight_quant.block_structure
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()

    @classmethod
    def get_min_capability(cls) -> int:
        # lovelace and up
        return 89

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       weight_loader: Callable, **kwargs):
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.weight_block_size = None

        if self.strategy == QuantizationStrategy.BLOCK:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            # Validate block quantization shapes
            validate_fp8_block_shape(layer, input_size, output_size,
                                     input_size_per_partition,
                                     output_partition_sizes,
                                     self.weight_block_size)

        # WEIGHT
        weight = create_fp8_weight_parameter(output_size_per_partition,
                                             input_size_per_partition,
                                             weight_loader)
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        weight_scale = create_fp8_scale_parameter(
            strategy_to_parameter_type[self.strategy], output_partition_sizes,
            input_size_per_partition, layer.weight_block_size, weight_loader)
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE
        if self.is_static_input_scheme:
            input_scale = create_fp8_input_scale(output_partition_sizes,
                                                 weight_loader)
            layer.register_parameter("input_scale", input_scale)

    def process_weights_after_loading(self, layer) -> None:
        # If per tensor, when we have a fused module (e.g. QKV) with per
        # tensor scales (thus N scales being passed to the kernel),
        # requantize so we can always run per tensor
        if self.strategy == QuantizationStrategy.TENSOR:
            max_w_scale, weight = requantize_with_max_scale(
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                logical_widths=layer.logical_widths,
            )
            weight, weight_scale, input_scale = (
                process_fp8_weight_tensor_strategy(
                    layer.weight, layer.weight_scale, layer.logical_widths,
                    getattr(layer, 'input_scale', None)))
            weight = weight.t()

            if current_platform.is_fp8_fnuz():
                input_scale = getattr(layer, 'input_scale', None)

                weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                    weight=weight,
                    weight_scale=max_w_scale,
                    input_scale=input_scale)
                if input_scale is not None:
                    layer.input_scale = Parameter(input_scale,
                                                  requires_grad=False)

            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

        # If channelwise, scales are already lined up, so just transpose.
        elif self.strategy == QuantizationStrategy.CHANNEL:
            weight = layer.weight
            weight, weight_scale, input_scale = (
                process_fp8_weight_channel_strategy(
                    layer.weight, layer.weight_scale,
                    getattr(layer, 'input_scale', None)))
            weight = weight.t()

            if current_platform.is_fp8_fnuz():
                input_scale = getattr(layer, 'input_scale', None)

                weight, weight_scale, input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=weight,
                        weight_scale=layer.weight_scale,
                        input_scale=input_scale)
                if input_scale is not None:
                    layer.input_scale = Parameter(input_scale,
                                                  requires_grad=False)
            else:
                weight_scale = layer.weight_scale.data

            layer.weight = Parameter(weight.t(), requires_grad=False)
            # required by torch.compile to be torch.nn.Parameter
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
        elif self.strategy == QuantizationStrategy.BLOCK:
            assert self.is_static_input_scheme is False
            weight, weight_scale = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale)
            input_scale = None

        else:
            raise ValueError(f"Unknown quantization strategy {self.strategy}")

        # required by torch.compile to be torch.nn.Parameter
        layer.weight = Parameter(weight.data, requires_grad=False)
        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
        if input_scale is not None:
            layer.input_scale = Parameter(input_scale.data,
                                          requires_grad=False)

        # INPUT SCALE
        if self.is_static_input_scheme and hasattr(layer, 'input_scale'):
            layer.input_scale = Parameter(layer.input_scale.max(),
@@ -96,58 +132,23 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
        else:
            layer.input_scale = None

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: list[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=torch.float8_e4m3fn),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        # TODO: update create_xxx_parameter functions to return
        # the newly added parameters
        if self.strategy == QuantizationStrategy.CHANNEL:
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes), 1),
                                 dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader)
        else:
            assert self.strategy == QuantizationStrategy.TENSOR
            weight_scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader)

        # min requirement for fp8 kernels
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE
        if self.is_static_input_scheme:
            input_scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader)
            input_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", input_scale)
        if self.strategy == QuantizationStrategy.BLOCK:
            maybe_post_process_fp8_weight_block(
                layer, self.cutlass_block_fp8_supported)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        if layer.weight_block_size is not None:
            return apply_fp8_block_linear(
                layer,
                input=x,
                bias=bias,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported)

        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
@@ -4,7 +4,6 @@
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import torch
import torch.nn.functional as F
from torch.nn import Module
from torch.nn.parameter import Parameter

@@ -32,8 +31,12 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace,
    should_use_deepgemm_for_fp8_linear)
    apply_fp8_block_linear, check_aiter_fp8_linear_support,
    create_fp8_input_scale, create_fp8_scale_parameter,
    create_fp8_weight_parameter, get_col_major_tma_aligned_tensor,
    maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace,
    validate_fp8_block_shape)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin)
@@ -42,8 +45,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
    cutlass_fp8_supported, maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
    requantize_with_max_scale)
    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
                                           PerTensorScaleParameter)
@@ -233,14 +235,10 @@ class Fp8LinearMethod(LinearMethodBase):
        if current_platform.is_rocm():
            self.use_marlin = False

        # AITER is only supported on ROCm and only for FP8_FNUZ
        # and at the moment are MI300 series
        self.use_aiter_and_is_supported = (current_platform.is_rocm()
                                           and envs.VLLM_ROCM_USE_AITER
                                           and envs.VLLM_ROCM_USE_AITER_LINEAR
                                           and current_platform.is_fp8_fnuz())
        self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()

        self.block_quant = self.quant_config.weight_block_size is not None
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        # Use per-token quantization for better perf if dynamic and cutlass
        if not self.act_q_static and cutlass_fp8_supported():
@@ -273,51 +271,27 @@ class Fp8LinearMethod(LinearMethodBase):
        layer.weight_block_size = None

        if self.block_quant:
            tp_size = getattr(layer, "tp_size",
                              get_tensor_model_parallel_world_size())
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # Required by row parallel
            if (tp_size > 1
                    and input_size // input_size_per_partition == tp_size
                    and input_size_per_partition % block_k != 0):
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")
            # Required by column parallel or enabling merged weights
            is_tp_split = (tp_size > 1 and
                           output_size // output_size_per_partition == tp_size)
            is_merged_gemm = len(output_partition_sizes) > 1
            if is_tp_split or is_merged_gemm:
                sizes_to_check = output_partition_sizes
                if not is_tp_split and is_merged_gemm:
                    # In case of merged matrices, we allow the last
                    # matrix to not be a multiple of block size
                    sizes_to_check = output_partition_sizes[:-1]
                for output_partition_size in sizes_to_check:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}.")
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(layer, input_size, output_size,
                                     input_size_per_partition,
                                     output_partition_sizes,
                                     self.weight_block_size)

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader)
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(output_size_per_partition,
                                                 input_size_per_partition,
                                                 weight_loader)
        else:
            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=params_dtype),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
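In the hunk above, create_weights only allocates an FP8 weight tensor when the checkpoint is already FP8-serialized; otherwise it keeps the original dtype and the weights are quantized later in process_weights_after_loading. A tiny self-contained sketch of that dtype choice (the flag value is assumed for illustration; needs a torch build with float8 support):

    import torch

    is_checkpoint_fp8_serialized = True   # assumed for this sketch
    params_dtype = torch.bfloat16
    dtype = (torch.float8_e4m3fn
             if is_checkpoint_fp8_serialized else params_dtype)
    weight = torch.empty(256, 512, dtype=dtype)
    # FP8 weights use 1 byte per element; bf16 weights use 2.
    expected = 1 if is_checkpoint_fp8_serialized else 2
    assert weight.element_size() == expected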
@@ -325,154 +299,87 @@ class Fp8LinearMethod(LinearMethodBase):
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                scale = create_fp8_scale_parameter(PerTensorScaleParameter,
                                                   output_partition_sizes,
                                                   input_size_per_partition,
                                                   None, weight_loader)
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert self.quant_config.activation_scheme == "dynamic"
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(BlockQuantScaleParameter,
                                                   output_partition_sizes,
                                                   input_size_per_partition,
                                                   self.weight_block_size,
                                                   weight_loader)
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                    weight_loader=weight_loader)

                scale[:] = torch.finfo(torch.float32).min
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes,
                                               weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        # Pad the weight tensor. This is an optimization on ROCm platform, which
        # can benefit from tensors located far enough from one another in memory
        if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
                and weight.stride(-1) == 1
                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
            num_pad = 256 // weight.element_size()
            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
            torch.cuda.empty_cache()
        return weight

    def process_weights_after_loading(self, layer: Module) -> None:
        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            assert not self.act_q_static
            size_k_first = False
            if current_platform.is_fp8_fnuz():
                weight, weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=layer.weight,
                        weight_scale=layer.weight_scale_inv)
            else:
                weight = layer.weight.data
                weight_scale_inv = layer.weight_scale_inv.data

            weight = self._maybe_pad_weight(weight)

            # Torch.compile cannot use Parameter subclasses.
            layer.weight = Parameter(weight, requires_grad=False)
            layer.weight_scale_inv = Parameter(weight_scale_inv,
                                               requires_grad=False)
            weight, weight_scale = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv)
            # Delete the weight_scale_inv parameter to avoid confusion
            # with the weight_scale parameter
            del layer.weight_scale_inv

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)
            weight = qweight.t()

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            # layer.input_scale is None indicates dynamic quant and scale is
            # computed from input.
            layer.input_scale = None

        # If checkpoint is fp8, handle that there are N scales for N
        # If checkpoint is fp8 per-tensor, handle that there are N scales for N
        # shards in a fused module
        else:
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                       requires_grad=False)

            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                # Dequant -> Quant with max scale so we can run per tensor.
                if current_platform.is_fp8_fnuz():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
                            weight=weight,
                            weight_scale=weight_scale,
                            input_scale=layer.input_scale)
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale,
                                                      requires_grad=False)
                weight, weight_scale, input_scale = (
                    process_fp8_weight_tensor_strategy(
                        weight, weight_scale, layer.logical_widths,
                        getattr(layer, 'input_scale', None)))
                if self.act_q_static:
                    assert input_scale is not None
                    input_scale = input_scale.max()
                weight = weight.t()

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            weight = self._maybe_pad_weight(weight)
            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)
        # Update layer with new values.
        layer.weight = Parameter(weight.data, requires_grad=False)
        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
        layer.input_scale = Parameter(
            input_scale,
            requires_grad=False) if input_scale is not None else None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
        # requantize the weight and input to the specific scale
        # at the same time.
        if is_deep_gemm_e8m0_used() and self.block_quant:
            assert layer.weight_block_size is not None
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.weight.data,
                layer.weight_scale_inv.data if hasattr(
                    layer, "weight_scale_inv") else layer.weight_scale.data,
                block_sz,
            )

        # SM90 Block FP8 CUTLASS requires row-major weight scales
        if (self.block_quant and current_platform.is_device_capability(90)
                and self.cutlass_block_fp8_supported
                and not should_use_deepgemm_for_fp8_linear(
                    torch.bfloat16, layer.weight)):
            layer.weight_scale_inv = Parameter(
                layer.weight_scale_inv.data.T.contiguous(),
                requires_grad=False)
        if self.block_quant:
            maybe_post_process_fp8_weight_block(
                layer, self.cutlass_block_fp8_supported)

    def apply(self,
              layer: torch.nn.Module,
@@ -490,18 +397,12 @@ class Fp8LinearMethod(LinearMethodBase):
                                   bias=bias)

        if self.block_quant:
            assert self.quant_config.weight_block_size is not None

            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
            return apply_fp8_block_linear(
                layer,
                input=x,
                weight=layer.weight,
                block_size=self.quant_config.weight_block_size,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
                use_aiter_and_is_supported=self.use_aiter_and_is_supported)

        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
@@ -528,7 +429,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None

        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
        self.fused_experts: Optional[
@@ -590,12 +492,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
        if self.block_quant:
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
@@ -952,7 +854,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                "BatchedTritonOrDeepGemmExperts(%s): "
                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                self.__class__.__name__, max_num_tokens_per_rank,
                self.quant_config.weight_block_size, False)
                self.weight_block_size, False)
            return BatchedTritonOrDeepGemmExperts(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
@@ -969,8 +871,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        else:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__, self.quant_config.weight_block_size,
                False)
                self.__class__.__name__, self.weight_block_size, False)
            return TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
@@ -988,7 +889,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     if self.block_quant else layer.w2_weight_scale),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.quant_config.weight_block_size,
            block_shape=self.weight_block_size,
        )

    def apply(
@@ -1046,7 +947,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            intermediate_size=layer.intermediate_size_per_partition,
            expert_offset=layer.ep_rank * layer.local_num_experts,
            local_num_experts=layer.local_num_experts,
            block_shape=self.quant_config.weight_block_size,
            block_shape=self.weight_block_size,
            routed_scaling=routed_scaling_factor,
        )
        else:
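The Fp8MoEMethod hunks above cache weight_block_size on the method and use it in place of repeated quant_config lookups; the NOTE in the create_weights hunk also concerns aligning the merged gate/up output sizes with the block-wise scales. A small self-contained divisibility check of the kind such alignment implies (all sizes are made up for the example, not taken from vLLM):

    block_n, block_k = 128, 128               # assumed block shape
    intermediate_size_per_partition = 1536    # per-rank gate/up width
    hidden_size = 4096
    # Each tensor-parallel partition of the merged w13 weight should cover
    # whole scale blocks along both dimensions.
    assert intermediate_size_per_partition % block_n == 0
    assert hidden_size % block_k == 0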
@@ -17,6 +17,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
    group_broadcast)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ChannelQuantScaleParameter,
                                           PerTensorScaleParameter)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, direct_register_custom_op
@@ -794,3 +797,220 @@ def requant_weight_ue8m0_inplace(
    # Write back the results in-place.
    w_q.copy_(w_requant)
    s_old.copy_(s_requant)


def check_aiter_fp8_linear_support() -> bool:
    """AITER is only supported on ROCm and only for FP8_FNUZ
    and at the moment are MI300 series"""
    return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
            and envs.VLLM_ROCM_USE_AITER_LINEAR
            and current_platform.is_fp8_fnuz())


def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
    """Pad the weight tensor. This is an optimization on ROCm platform, which
    can benefit from tensors located far enough from one another in memory"""
    if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
            and weight.stride(-1) == 1
            and (weight.stride(-2) * weight.element_size()) % 512 == 0):
        num_pad = 256 // weight.element_size()
        import torch.nn.functional as F
        weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
        torch.cuda.empty_cache()
    return weight

def validate_fp8_block_shape(layer: torch.nn.Module, input_size: int,
                             output_size: int, input_size_per_partition: int,
                             output_partition_sizes: list[int],
                             block_size: list[int]) -> None:
    """Validate block quantization shapes for tensor parallelism."""
    from vllm.distributed import get_tensor_model_parallel_world_size

    tp_size = getattr(layer, "tp_size", get_tensor_model_parallel_world_size())
    block_n, block_k = block_size[0], block_size[1]

    # Required by row parallel
    if (tp_size > 1 and input_size // input_size_per_partition == tp_size
            and input_size_per_partition % block_k != 0):
        raise ValueError(
            f"Weight input_size_per_partition = {input_size_per_partition} "
            f"is not divisible by weight quantization block_k = {block_k}.")

    # Required by column parallel or enabling merged weights
    is_tp_split = (tp_size > 1
                   and output_size // sum(output_partition_sizes) == tp_size)
    is_merged_gemm = len(output_partition_sizes) > 1
    if is_tp_split or is_merged_gemm:
        sizes_to_check = output_partition_sizes
        if not is_tp_split and is_merged_gemm:
            # In case of merged matrices, we allow the last
            # matrix to not be a multiple of block size
            sizes_to_check = output_partition_sizes[:-1]
        for output_partition_size in sizes_to_check:
            if output_partition_size % block_n != 0:
                raise ValueError(
                    f"Weight output_partition_size = "
                    f"{output_partition_size} is not divisible by "
                    f"weight quantization block_n = {block_n}.")

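validate_fp8_block_shape enforces two divisibility rules: a row-parallel shard's input size must be a multiple of block_k, and (except for the last matrix of a merged GEMM) every output partition must be a multiple of block_n. A self-contained numeric illustration of those rules (the helper below is a simplified stand-in, not the vLLM function):

    def check_block_shape(input_size_per_partition: int,
                          output_partition_sizes: list[int],
                          block_n: int, block_k: int) -> None:
        # Row-parallel: the K shard must cover whole blocks of block_k.
        if input_size_per_partition % block_k != 0:
            raise ValueError("input shard not divisible by block_k")
        # Column-parallel / merged weights: output partitions must cover
        # whole blocks of block_n (vLLM only exempts the last matrix of a
        # merged, non-TP-split module; simplified here to skip the last).
        for size in output_partition_sizes[:-1]:
            if size % block_n != 0:
                raise ValueError("output partition not divisible by block_n")

    # e.g. a merged gate/up projection with 128x128 blocks passes the check:
    check_block_shape(4096, [11008, 11008], block_n=128, block_k=128)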
def create_fp8_weight_parameter(
        output_size_per_partition: int, input_size_per_partition: int,
        weight_loader: Optional[Callable]) -> torch.nn.Parameter:
    """Create FP8 weight parameter."""
    from vllm.model_executor.parameter import ModelWeightParameter

    return ModelWeightParameter(data=torch.empty(output_size_per_partition,
                                                 input_size_per_partition,
                                                 dtype=torch.float8_e4m3fn),
                                input_dim=1,
                                output_dim=0,
                                weight_loader=weight_loader)


def create_fp8_scale_parameter(
        parameter_type: torch.nn.Parameter, output_partition_sizes: list[int],
        input_size_per_partition: int, block_size: Optional[list[int]],
        weight_loader: Optional[Callable]) -> torch.nn.Parameter:
    """Create scale parameter based on quantization strategy."""
    if parameter_type == ChannelQuantScaleParameter:
        scale = parameter_type(data=torch.empty(
            (sum(output_partition_sizes), 1), dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader)
    elif parameter_type == BlockQuantScaleParameter:
        assert block_size is not None
        block_n, block_k = block_size[0], block_size[1]
        output_size_per_partition = sum(output_partition_sizes)
        scale = parameter_type(
            data=torch.empty(
                (output_size_per_partition + block_n - 1) // block_n,
                (input_size_per_partition + block_k - 1) // block_k,
                dtype=torch.float32,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
    elif parameter_type == PerTensorScaleParameter:
        scale = parameter_type(data=torch.empty(len(output_partition_sizes),
                                                dtype=torch.float32),
                               weight_loader=weight_loader)
    else:
        raise ValueError(f"Unknown parameter type: {parameter_type}")

    scale[:] = torch.finfo(torch.float32).min
    return scale

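For the block strategy, the scale tensor has one float32 entry per (block_n x block_k) tile of the weight, with ceiling division so partial edge tiles still get a scale. A self-contained sketch of that shape math in plain torch (sizes chosen for illustration):

    import torch

    def block_scale_shape(n: int, k: int, block_n: int, block_k: int):
        # One scale per tile, rounding up for ragged edges.
        return ((n + block_n - 1) // block_n, (k + block_k - 1) // block_k)

    n, k = 6144, 4096                       # output x input of one partition
    shape = block_scale_shape(n, k, 128, 128)
    assert shape == (48, 32)
    scale = torch.full(shape, torch.finfo(torch.float32).min)
    # Per-tensor and per-channel strategies instead use shapes
    # (len(output_partition_sizes),) and (sum(output_partition_sizes), 1).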
def create_fp8_input_scale(
        output_partition_sizes: list[int],
        weight_loader: Optional[Callable]) -> torch.nn.Parameter:
    """Create input scale parameter for static activation quantization."""
    from vllm.model_executor.parameter import PerTensorScaleParameter

    scale = PerTensorScaleParameter(data=torch.empty(
        len(output_partition_sizes), dtype=torch.float32),
        weight_loader=weight_loader)
    scale[:] = torch.finfo(torch.float32).min
    return scale


def process_fp8_weight_tensor_strategy(
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        logical_widths: list[int],
        input_scale: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Process weights for tensor-wise quantization strategy."""
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)

    if current_platform.is_fp8_fnuz():
        weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
            weight=weight, weight_scale=weight_scale, input_scale=input_scale)

    # Requantize with max scale
    weight_scale, weight = requantize_with_max_scale(
        weight=weight,
        weight_scale=weight_scale,
        logical_widths=logical_widths,
    )

    weight = _maybe_pad_fp8_weight(weight)
    return weight, weight_scale, input_scale

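The requantize-with-max-scale step exists because a fused module (e.g. QKV) is loaded with one per-tensor scale per logical shard, while the per-tensor kernels want a single scale for the whole fused weight. A minimal float32 sketch of the dequantize-then-requantize idea (plain tensors standing in for the FP8 weight; this is not the vLLM implementation):

    import torch

    def requantize_with_max_scale_ref(weight_q, scales, logical_widths):
        # weight_q: fused quantized weight (rows = sum(logical_widths))
        # scales:   one per-tensor scale per logical shard
        max_scale = scales.max()
        out = weight_q.clone().float()
        start = 0
        for shard_scale, width in zip(scales, logical_widths):
            end = start + width
            # Dequantize with the shard's own scale, requantize with the max.
            out[start:end] = weight_q[start:end].float() * shard_scale / max_scale
            start = end
        return max_scale, out

    w = torch.randn(6, 4)                 # pretend-quantized fused weight
    scales = torch.tensor([0.5, 2.0])     # per-shard scales (e.g. q and kv)
    max_scale, w_req = requantize_with_max_scale_ref(w, scales, [3, 3])
    # Dequantizing with the single max_scale recovers the original values.
    assert torch.allclose(w_req * max_scale,
                          torch.cat([w[:3] * 0.5, w[3:] * 2.0]))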
def process_fp8_weight_channel_strategy(
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        input_scale: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Process weights for channel-wise quantization strategy."""
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        normalize_e4m3fn_to_e4m3fnuz)

    if current_platform.is_fp8_fnuz():
        weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
            weight=weight, weight_scale=weight_scale, input_scale=input_scale)

    return weight, weight_scale, input_scale


def process_fp8_weight_block_strategy(
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Process weights for block-wise quantization strategy."""
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        normalize_e4m3fn_to_e4m3fnuz)

    if current_platform.is_fp8_fnuz():
        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
            weight=weight, weight_scale=weight_scale)

    weight = _maybe_pad_fp8_weight(weight)
    return weight, weight_scale


def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
                                        cutlass_block_fp8_supported: bool):
    assert layer.weight_block_size is not None

    from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
                                      should_use_deepgemm_for_fp8_linear)

    # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
    # requantize the weight and input to the specific scale
    # at the same time.
    if is_deep_gemm_e8m0_used():
        block_sz = tuple(layer.weight_block_size)
        requant_weight_ue8m0_inplace(layer.weight.data,
                                     layer.weight_scale.data, block_sz)
    # SM90 Block FP8 CUTLASS requires row-major weight scales
    elif (current_platform.is_device_capability(90)
          and cutlass_block_fp8_supported
          and not should_use_deepgemm_for_fp8_linear(torch.bfloat16,
                                                     layer.weight)):
        layer.weight_scale = torch.nn.Parameter(
            layer.weight_scale.data.T.contiguous(), requires_grad=False)

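The SM90 branch above only re-lays-out the block scales: .T.contiguous() materializes the transposed scale tensor in contiguous (row-major) memory for the CUTLASS kernel, leaving the values unchanged. A small stride illustration in plain torch:

    import torch

    scales = torch.randn(48, 32)             # (ceil(N/block_n), ceil(K/block_k))
    row_major_t = scales.T.contiguous()      # (32, 48), freshly laid out
    assert row_major_t.shape == (32, 48)
    assert row_major_t.stride() == (48, 1)   # contiguous over the new last dim
    assert torch.equal(row_major_t.T, scales)  # same values, new layout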
def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor,
                           bias: Optional[torch.Tensor],
                           cutlass_block_fp8_supported: bool,
                           use_aiter_and_is_supported: bool) -> torch.Tensor:
    """Apply block-wise FP8 linear operation."""
    assert layer.weight_block_size is not None

    return torch.ops.vllm.apply_w8a8_block_fp8_linear(
        input=input,
        weight=layer.weight,
        block_size=layer.weight_block_size,
        weight_scale=layer.weight_scale,
        input_scale=layer.input_scale,
        bias=bias,
        cutlass_block_fp8_supported=cutlass_block_fp8_supported,
        use_aiter_and_is_supported=use_aiter_and_is_supported,
    )
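apply_fp8_block_linear just forwards to the fused custom op with the layer's block size and scales. As a rough mental model of what that op computes, here is a slow pure-torch reference that dequantizes a block-scaled weight and runs a plain linear (float tensors stand in for FP8; an illustration, not the op's implementation):

    import torch

    def block_dequant_linear_ref(x, w_q, w_scale, block_n, block_k, bias=None):
        # w_q: (N, K) quantized weight, w_scale: (ceil(N/bn), ceil(K/bk))
        n, k = w_q.shape
        # Expand each per-block scale over its (block_n x block_k) tile.
        scale_full = w_scale.repeat_interleave(block_n, dim=0)[:n] \
                            .repeat_interleave(block_k, dim=1)[:, :k]
        w = w_q.float() * scale_full
        return torch.nn.functional.linear(x, w, bias)

    x = torch.randn(2, 256)
    w_q = torch.randn(128, 256)           # pretend-quantized weight
    w_scale = torch.rand(1, 2) + 0.5      # ceil(128/128)=1, ceil(256/128)=2
    y = block_dequant_linear_ref(x, w_q, w_scale, 128, 128)
    assert y.shape == (2, 128)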