From f5e6cd9695848739d56acc46f89b29db8e0769bf Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 4 Nov 2025 12:11:13 +0000 Subject: [PATCH] prefer QuantKey over ScaledMMLinearQuantStrategy Signed-off-by: vllmellm --- tests/compile/test_functionalization.py | 19 ++++----- tests/compile/test_fusion.py | 26 ++++++------ tests/compile/test_fusion_all_reduce.py | 14 +++---- tests/compile/test_fusion_attn.py | 14 ++----- tests/compile/test_sequence_parallelism.py | 12 +++--- tests/compile/test_silu_mul_quant_fusion.py | 12 ++---- .../schemes/compressed_tensors_w8a8_fp8.py | 39 +++++++++--------- .../layers/quantization/fbgemm_fp8.py | 17 +++----- .../model_executor/layers/quantization/fp8.py | 20 +++++----- .../kernels/scaled_mm/ScaledMMLinearKernel.py | 30 +++++--------- .../kernels/scaled_mm/__init__.py | 13 +++--- .../kernels/scaled_mm/flash_infer.py | 7 ++-- .../quantization/kernels/scaled_mm/pytorch.py | 19 +++++---- .../quantization/kernels/scaled_mm/rocm.py | 21 +++++----- .../layers/quantization/modelopt.py | 12 ++---- .../layers/quantization/ptpc_fp8.py | 16 +++----- .../quark/schemes/quark_w8a8_fp8.py | 40 +++++++++---------- .../layers/quantization/utils/quant_utils.py | 3 ++ 18 files changed, 147 insertions(+), 187 deletions(-) diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 4d979f075d782..a40f8beccdc20 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -23,10 +23,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearQuantStrategy, +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.platforms import current_platform @@ -37,6 +36,8 @@ FP8_DTYPE = current_platform.fp8_dtype() class TestSiluMul(torch.nn.Module): + quant_key = kFp8StaticTensorSym + def __init__(self, hidden_size: int = 128): super().__init__() self.silu_and_mul = SiluAndMul() @@ -46,9 +47,8 @@ class TestSiluMul(torch.nn.Module): if TEST_FP8: self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() self.fp8_linear = init_fp8_linear_kernel( - act_q_static=True, - act_q_group_shape=GroupShape.PER_TENSOR, - weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + activation_quant_key=self.quant_key, + weight_quant_key=self.quant_key, out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) @@ -74,6 +74,8 @@ class TestSiluMul(torch.nn.Module): class TestFusedAddRMSNorm(torch.nn.Module): + quant_key = kFp8StaticTensorSym + def __init__(self, hidden_size=16, intermediate_size=32): super().__init__() self.hidden_size = hidden_size @@ -89,9 +91,8 @@ class TestFusedAddRMSNorm(torch.nn.Module): if TEST_FP8: self.fp8_linear = init_fp8_linear_kernel( - act_q_static=True, - act_q_group_shape=GroupShape.PER_TENSOR, - weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + activation_quant_key=self.quant_key, + weight_quant_key=self.quant_key, out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index ed925a4d55cca..6270344c2eb35 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -21,9 +21,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearQuantStrategy, -) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, @@ -59,10 +56,16 @@ class TestModel(torch.nn.Module): self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN - weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR - quant_scale = ScaleDesc(torch.float32, static, group_shape) - self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) + act_quant_scale = ScaleDesc(torch.float32, static, group_shape) + w_quant_scale = ScaleDesc(torch.float32, True, group_shape) + self.activation_quant_key = QuantKey( + dtype=FP8_DTYPE, scale=act_quant_scale, symmetric=True + ) + self.weight_quant_key = QuantKey( + dtype=FP8_DTYPE, scale=w_quant_scale, symmetric=True + ) + if static: self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] else: @@ -74,9 +77,8 @@ class TestModel(torch.nn.Module): with override_cutlass_fp8_supported(not cuda_force_torch): self.fp8_linear = init_fp8_linear_kernel( - act_q_static=static, - act_q_group_shape=group_shape, - weight_quant_strategy=weight_quant_strategy, + activation_quant_key=self.activation_quant_key, + weight_quant_key=self.weight_quant_key, out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) @@ -110,13 +112,13 @@ class TestModel(torch.nn.Module): def ops_in_model_after(self): return [ - FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)], - FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)], + FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, True)], + FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, False)], ] def ops_in_model_before(self): return ( - [QUANT_OPS[self.quant_key]] + [QUANT_OPS[self.activation_quant_key]] if self.enable_quant_fp8_custom_op else [torch.ops.aten.reciprocal] ) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 2dc6f8d2f925d..5e2c46f8ea919 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -29,11 +29,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearQuantStrategy, -) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - GroupShape, +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, ) from vllm.platforms import current_platform from vllm.utils.system_utils import update_environment_variables @@ -80,6 +77,8 @@ class TestAllReduceRMSNormModel(torch.nn.Module): class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): + quant_key = kFp8StaticTensorSym + def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size @@ -95,9 +94,8 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): ] self.fp8_linear = init_fp8_linear_kernel( - act_q_static=True, - act_q_group_shape=GroupShape.PER_TENSOR, - weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + activation_quant_key=self.quant_key, + weight_quant_key=self.quant_key, out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index a6ebf46d98ddb..9068e304f551d 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -31,9 +31,6 @@ from vllm.forward_context import get_forward_context, set_forward_context from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearQuantStrategy, -) from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, @@ -177,18 +174,13 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if self.quant_key.scale.group_shape.is_per_tensor(): - weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR - else: - weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL - self.fp8_linear = init_fp8_linear_kernel( - act_q_static=self.quant_key.scale.static, - act_q_group_shape=self.quant_key.scale.group_shape, - weight_quant_strategy=weight_quant_strategy, + activation_quant_key=self.quant_key, + weight_quant_key=self.quant_key, out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) + hidden_size = self.num_qo_heads * self.head_size self.w = kwargs.get( "w", diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index 007339cd86f7b..f579815338a98 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -30,10 +30,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearQuantStrategy, +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform from vllm.utils.system_utils import update_environment_variables @@ -101,6 +100,8 @@ class TestModel(torch.nn.Module): class TestQuantModel(torch.nn.Module): + quant_key = kFp8StaticTensorSym + def __init__(self, hidden_size=16, intermediate_size=32): super().__init__() self.hidden_size = hidden_size @@ -114,9 +115,8 @@ class TestQuantModel(torch.nn.Module): torch.nn.init.normal_(self.gate_proj, std=0.02) self.fp8_linear = init_fp8_linear_kernel( - act_q_static=True, - act_q_group_shape=GroupShape.PER_TENSOR, - weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + activation_quant_key=self.quant_key, + weight_quant_key=self.quant_key, out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 2ce52b97f13e3..20e7c2955d01b 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -27,11 +27,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearQuantStrategy, -) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, kFp8StaticTensorSym, kNvfp4Quant, ) @@ -52,6 +48,8 @@ def is_nvfp4_supported(): class TestSiluMulFp8QuantModel(torch.nn.Module): + quant_key = kFp8StaticTensorSym + def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs): super().__init__() self.silu_and_mul = SiluAndMul() @@ -62,13 +60,11 @@ class TestSiluMulFp8QuantModel(torch.nn.Module): with override_cutlass_fp8_supported(not cuda_force_torch): self.fp8_linear = init_fp8_linear_kernel( - act_q_static=True, - act_q_group_shape=GroupShape.PER_TENSOR, - weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + activation_quant_key=self.quant_key, + weight_quant_key=self.quant_key, out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) - self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 58ea30edcd639..0d14c13180ab1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -14,9 +14,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - QUANT_STRATEGY_MAP, -) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support, @@ -29,7 +26,11 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( process_fp8_weight_tensor_strategy, validate_fp8_block_shape, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_block_fp8_supported, maybe_create_device_identity, @@ -48,6 +49,12 @@ strategy_to_parameter_type = { QuantizationStrategy.TENSOR: PerTensorScaleParameter, } +STATIC_QUANT = True +DYNAMIC_QUANT = False +quant_keys = { + STATIC_QUANT: (kFp8StaticTensorSym, kFp8StaticTensorSym), + DYNAMIC_QUANT: (kFp8DynamicTokenSym, kFp8StaticTensorSym), +} logger = init_logger(__name__) @@ -57,22 +64,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): self.strategy = weight_quant.strategy self.out_dtype = torch.get_default_dtype() self.is_static_input_scheme = is_static_input_scheme - self.weight_block_size = self.weight_quant.block_structure - if self.weight_block_size is not None: - self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) - else: - self.act_q_group_shape = ( - GroupShape.PER_TENSOR - if is_static_input_scheme - else GroupShape.PER_TOKEN - ) - self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() if self.weight_block_size is not None: assert not self.is_static_input_scheme + self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( weight_group_shape=GroupShape(*self.weight_block_size), act_quant_group_shape=self.act_q_group_shape, @@ -80,12 +78,11 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) else: - weight_quant_strategy = QUANT_STRATEGY_MAP[self.strategy] - self.fp8_linear_kernel = init_fp8_linear_kernel( - act_q_static=self.is_static_input_scheme, - act_q_group_shape=self.act_q_group_shape, - weight_quant_strategy=weight_quant_strategy, - out_dtype=self.out_dtype, + activation_quant_key, weight_quant_key = quant_keys[is_static_input_scheme] + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=activation_quant_key, + weight_quant_key=weight_quant_key, + out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) @@ -204,4 +201,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): bias=bias, ) - return self.fp8_linear_kernel.apply_weights(layer, x, bias) + return self.fp8_linear.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 5fa419ebaa91a..c19dd708b2339 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -21,16 +21,13 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearQuantStrategy, -) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, is_layer_skipped, + kFp8DynamicTokenSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( maybe_create_device_identity, @@ -97,12 +94,10 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: FBGEMMFp8Config): self.quant_config = quant_config self.out_dtype = torch.get_default_dtype() - - self.fp8_linear_kernel = init_fp8_linear_kernel( - act_q_static=False, - act_q_group_shape=GroupShape.PER_TOKEN, - weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL, - out_dtype=self.out_dtype, + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=kFp8DynamicTokenSym, + weight_quant_key=kFp8DynamicTokenSym, + out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) @@ -194,4 +189,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): bias=bias, ) - return self.fp8_linear_kernel.apply_weights(layer, x, bias) + return self.fp8_linear.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 48697e3849e05..c04bcef7bb0b5 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -45,10 +45,6 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa E501 - FP8ScaledMMLinearLayerConfig, - ScaledMMLinearQuantStrategy, -) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, @@ -82,6 +78,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, is_layer_skipped, + kFp8DynamicTensorSym, + kFp8StaticTensorSym, + kFp8StaticTokenSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, @@ -380,8 +379,10 @@ class Fp8LinearMethod(LinearMethodBase): # Use per-token quantization for better perf if dynamic and cutlass if not self.act_q_static and cutlass_fp8_supported(): self.act_q_group_shape = GroupShape.PER_TOKEN + self.activation_quant_key = kFp8StaticTokenSym else: self.act_q_group_shape = GroupShape.PER_TENSOR + self.activation_quant_key = kFp8DynamicTensorSym if self.block_quant: assert not self.act_q_static @@ -393,11 +394,10 @@ class Fp8LinearMethod(LinearMethodBase): use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) else: - self.fp8_linear_kernel = init_fp8_linear_kernel( - act_q_static=self.act_q_static, - act_q_group_shape=self.act_q_group_shape, - weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, - out_dtype=self.out_dtype, + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=self.activation_quant_key, + weight_quant_key=kFp8StaticTensorSym, + out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) @@ -684,7 +684,7 @@ class Fp8LinearMethod(LinearMethodBase): bias=bias, ) - return self.fp8_linear_kernel.apply_weights(layer, x, bias) + return self.fp8_linear.apply_weights(layer, x, bias) class Fp8MoEMethod(FusedMoEMethodBase): diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index f3bff8cae0ef7..a8a2fc245f62d 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -4,43 +4,32 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass -from enum import Enum from typing import Generic, TypeVar import torch -from compressed_tensors.quantization import QuantizationStrategy from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape - - -class ScaledMMLinearQuantStrategy(Enum): - TENSOR = "tensor" - CHANNEL = "channel" - BLOCK = "block" - - -QUANT_STRATEGY_MAP = { - QuantizationStrategy.TENSOR: ScaledMMLinearQuantStrategy.TENSOR, - QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.CHANNEL, -} +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) @dataclass class ScaledMMLinearLayerConfig: - is_static_input_scheme: bool + pass @dataclass class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): + is_static_input_scheme: bool is_channelwise: bool input_symmetric: bool @dataclass class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): - weight_quant_strategy: ScaledMMLinearQuantStrategy - activation_group_shape: GroupShape + weight_quant_key: QuantKey + activation_quant_key: QuantKey out_dtype: torch.dtype | None @@ -103,9 +92,10 @@ class FP8ScaledMMLinearKernel( def __init__( self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str] ) -> None: + act_scale_descriptor = c.activation_quant_key.scale self.quant_fp8 = QuantFP8( - static=c.is_static_input_scheme, - group_shape=c.activation_group_shape, + static=act_scale_descriptor.static, + group_shape=act_scale_descriptor.group_shape, num_token_padding=self.get_ouput_padding(), ) super().__init__(c, layer_param_names) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 901f0649a6d48..b36b77109e922 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -32,7 +32,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer Int8ScaledMMLinearLayerConfig, ScaledMMLinearKernel, ScaledMMLinearLayerConfig, - ScaledMMLinearQuantStrategy, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( TritonScaledMMLinearKernel, @@ -40,7 +39,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( XLAScaledMMLinearKernel, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.platforms import PlatformEnum, current_platform logger = init_logger(__name__) @@ -137,16 +136,14 @@ def choose_scaled_mm_linear_kernel( def init_fp8_linear_kernel( - act_q_static: bool, - act_q_group_shape: GroupShape, - weight_quant_strategy: ScaledMMLinearQuantStrategy, + activation_quant_key: QuantKey, + weight_quant_key: QuantKey, out_dtype: torch.dtype, module_name: str, ) -> FP8ScaledMMLinearKernel: scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( - is_static_input_scheme=act_q_static, - weight_quant_strategy=weight_quant_strategy, - activation_group_shape=act_q_group_shape, + weight_quant_key=weight_quant_key, + activation_quant_key=activation_quant_key, out_dtype=out_dtype, ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py index 9b0ac38db5e3c..3bac71950dda2 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py @@ -9,7 +9,6 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer from .ScaledMMLinearKernel import ( FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, - ScaledMMLinearQuantStrategy, ) from .utils import apply_weights_fp8 @@ -39,10 +38,10 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel): @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() - per_tensor_weight_scales = ( - c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR + per_tensor_activation_scales = ( + c.activation_quant_key.scale.group_shape.is_per_tensor() ) + per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() if not current_platform.is_cuda(): return ( diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py index c0466e840fc08..7c4c64215a8e9 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py @@ -10,7 +10,6 @@ from vllm.platforms import current_platform from .ScaledMMLinearKernel import ( FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, - ScaledMMLinearQuantStrategy, ) from .utils import apply_weights_fp8 @@ -143,10 +142,10 @@ class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel): class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() - per_tensor_weight_scales = ( - c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR + per_tensor_activation_scales = ( + c.activation_quant_key.scale.group_shape.is_per_tensor() ) + per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() if not (per_tensor_activation_scales and per_tensor_weight_scales): return ( @@ -183,10 +182,10 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() - per_tensor_weight_scales = ( - c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR + per_tensor_activation_scales = ( + c.activation_quant_key.scale.group_shape.is_per_tensor() ) + per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() if per_tensor_activation_scales or per_tensor_weight_scales: return ( @@ -237,10 +236,10 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() - per_tensor_weight_scales = ( - c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR + per_tensor_activation_scales = ( + c.activation_quant_key.scale.group_shape.is_per_tensor() ) + per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() if per_tensor_activation_scales and per_tensor_weight_scales: return ( diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py index 63744337a7e5a..26463a19c6f48 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py @@ -11,7 +11,6 @@ from vllm.utils.torch_utils import direct_register_custom_op from .ScaledMMLinearKernel import ( FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, - ScaledMMLinearQuantStrategy, ) from .utils import apply_weights_fp8 @@ -72,18 +71,18 @@ def rocm_per_tensor_float_w8a8_scaled_mm( bias: torch.Tensor, output_shape: list[int], ) -> torch.Tensor: - output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl( + output = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl( A, B, out_dtype, As, Bs, bias ) return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape) -if current_platform.is_rocm(): - direct_register_custom_op( - op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl", - op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl, - fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake, - ) +# if current_platform.is_rocm(): +direct_register_custom_op( + op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl", + op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl, + fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake, +) class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel): @@ -95,10 +94,10 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel): # TODO: check if this causes an issue on non-ROCM platforms from vllm.platforms.rocm import on_mi3xx - per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() - per_tensor_weight_scales = ( - c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR + per_tensor_activation_scales = ( + c.activation_quant_key.scale.group_shape.is_per_tensor() ) + per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() if not current_platform.is_rocm(): return ( diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index f478cd319e667..53b25af440353 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -40,9 +40,6 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearQuantStrategy, -) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( build_flashinfer_fp4_cutlass_moe_prepare_finalize, @@ -68,9 +65,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( prepare_moe_fp4_layer_for_marlin, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, cutlass_fp4_supported, is_layer_skipped, + kFp8StaticTensorSym, swizzle_blockscale, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -260,10 +257,9 @@ class ModelOptFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptFp8Config) -> None: self.quant_config = quant_config self.fp8_linear = init_fp8_linear_kernel( - act_q_static=True, - act_q_group_shape=GroupShape.PER_TENSOR, - weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, - out_dtype=None, + activation_quant_key=kFp8StaticTensorSym, + weight_quant_key=kFp8StaticTensorSym, + out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index 2634bbd4bd87e..c102c52bbe3f5 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -19,12 +19,9 @@ from vllm.model_executor.layers.quantization.fp8 import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa E501 - ScaledMMLinearQuantStrategy, -) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, is_layer_skipped, + kFp8DynamicTokenSym, ) from vllm.platforms import current_platform @@ -103,11 +100,10 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): ) super().__init__(quant_config=quant_config) # Force weight quantization - self.fp8_linear_kernel = init_fp8_linear_kernel( - act_q_static=False, - act_q_group_shape=GroupShape.PER_TOKEN, - weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL, - out_dtype=self.out_dtype, + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=kFp8DynamicTokenSym, + weight_quant_key=kFp8DynamicTokenSym, + out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) @@ -135,4 +131,4 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.fp8_linear_kernel.apply_weights(layer, x, bias) + return self.fp8_linear.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 6fff449000075..343539c10fa87 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -11,11 +11,13 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearQuantStrategy, -) from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, + kFp8StaticTokenSym, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale, @@ -31,11 +33,6 @@ __all__ = ["QuarkW8A8Fp8"] logger = init_logger(__name__) -QUANT_STRATEGY_MAP = { - "per_tensor": ScaledMMLinearQuantStrategy.TENSOR, - "per_channel": ScaledMMLinearQuantStrategy.CHANNEL, -} - class QuarkW8A8Fp8(QuarkScheme): def __init__( @@ -48,11 +45,16 @@ class QuarkW8A8Fp8(QuarkScheme): self.is_static_input_scheme = not cast(bool, input_config.get("is_dynamic")) self.input_qscheme = cast(str, input_config.get("qscheme")) - per_token = ( + per_token_activation = ( not self.is_static_input_scheme and self.input_qscheme == "per_channel" ) - self.act_quant_group_shape = ( - GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR + per_token_weight = self.weight_qscheme == "per_channel" + + self.activation_quant_key = ( + kFp8DynamicTokenSym if per_token_activation else kFp8StaticTensorSym + ) + self.weight_quant_key = ( + kFp8StaticTokenSym if per_token_weight else kFp8StaticTensorSym ) self.out_dtype = torch.get_default_dtype() @@ -103,7 +105,7 @@ class QuarkW8A8Fp8(QuarkScheme): layer.input_scale = Parameter(input_scale, requires_grad=False) else: weight_scale = layer.weight_scale.data - if self.act_quant_group_shape == GroupShape.PER_TOKEN: + if self.activation_quant_key.scale.group_shape == GroupShape.PER_TOKEN: weight_scale = weight_scale.view(-1, 1) layer.weight = Parameter(weight.t(), requires_grad=False) # required by torch.compile to be torch.nn.Parameter @@ -174,12 +176,10 @@ class QuarkW8A8Fp8(QuarkScheme): layer.register_parameter("input_scale_ub", None) - weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme] - self.fp8_linear_kernel = init_fp8_linear_kernel( - act_q_static=self.is_static_input_scheme, - act_q_group_shape=self.act_quant_group_shape, - weight_quant_strategy=weight_quant_strategy, - out_dtype=self.out_dtype, + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=self.activation_quant_key, + weight_quant_key=self.weight_quant_key, + out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) @@ -189,4 +189,4 @@ class QuarkW8A8Fp8(QuarkScheme): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.fp8_linear_kernel.apply_weights(layer, x, bias) + return self.fp8_linear.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index d056d3404385a..2c8a614c9e714 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -109,6 +109,9 @@ kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True) kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR) kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True) +kStaticTokenScale = ScaleDesc(torch.float32, True, GroupShape.PER_TOKEN) +kFp8StaticTokenSym = QuantKey(FP8_DTYPE, kStaticTokenScale, symmetric=True) + kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN) kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)