prefer QuantKey over ScaledMMLinearQuantStrategy

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-11-04 12:11:13 +00:00
parent a8010c7b1c
commit f5e6cd9695
18 changed files with 147 additions and 187 deletions

View File

@ -23,10 +23,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel, init_fp8_linear_kernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 from vllm.model_executor.layers.quantization.utils.quant_utils import (
ScaledMMLinearQuantStrategy, kFp8StaticTensorSym,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -37,6 +36,8 @@ FP8_DTYPE = current_platform.fp8_dtype()
class TestSiluMul(torch.nn.Module): class TestSiluMul(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size: int = 128): def __init__(self, hidden_size: int = 128):
super().__init__() super().__init__()
self.silu_and_mul = SiluAndMul() self.silu_and_mul = SiluAndMul()
@ -46,9 +47,8 @@ class TestSiluMul(torch.nn.Module):
if TEST_FP8: if TEST_FP8:
self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
self.fp8_linear = init_fp8_linear_kernel( self.fp8_linear = init_fp8_linear_kernel(
act_q_static=True, activation_quant_key=self.quant_key,
act_q_group_shape=GroupShape.PER_TENSOR, weight_quant_key=self.quant_key,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
out_dtype=torch.get_default_dtype(), out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__, module_name=self.__class__.__name__,
) )
@ -74,6 +74,8 @@ class TestSiluMul(torch.nn.Module):
class TestFusedAddRMSNorm(torch.nn.Module): class TestFusedAddRMSNorm(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size=16, intermediate_size=32): def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -89,9 +91,8 @@ class TestFusedAddRMSNorm(torch.nn.Module):
if TEST_FP8: if TEST_FP8:
self.fp8_linear = init_fp8_linear_kernel( self.fp8_linear = init_fp8_linear_kernel(
act_q_static=True, activation_quant_key=self.quant_key,
act_q_group_shape=GroupShape.PER_TENSOR, weight_quant_key=self.quant_key,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
out_dtype=torch.get_default_dtype(), out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__, module_name=self.__class__.__name__,
) )

View File

@ -21,9 +21,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel, init_fp8_linear_kernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
QuantKey, QuantKey,
@ -59,10 +56,16 @@ class TestModel(torch.nn.Module):
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR
quant_scale = ScaleDesc(torch.float32, static, group_shape) act_quant_scale = ScaleDesc(torch.float32, static, group_shape)
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) w_quant_scale = ScaleDesc(torch.float32, True, group_shape)
self.activation_quant_key = QuantKey(
dtype=FP8_DTYPE, scale=act_quant_scale, symmetric=True
)
self.weight_quant_key = QuantKey(
dtype=FP8_DTYPE, scale=w_quant_scale, symmetric=True
)
if static: if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
else: else:
@ -74,9 +77,8 @@ class TestModel(torch.nn.Module):
with override_cutlass_fp8_supported(not cuda_force_torch): with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear = init_fp8_linear_kernel( self.fp8_linear = init_fp8_linear_kernel(
act_q_static=static, activation_quant_key=self.activation_quant_key,
act_q_group_shape=group_shape, weight_quant_key=self.weight_quant_key,
weight_quant_strategy=weight_quant_strategy,
out_dtype=torch.get_default_dtype(), out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__, module_name=self.__class__.__name__,
) )
@ -110,13 +112,13 @@ class TestModel(torch.nn.Module):
def ops_in_model_after(self): def ops_in_model_after(self):
return [ return [
FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)], FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, True)],
FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)], FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, False)],
] ]
def ops_in_model_before(self): def ops_in_model_before(self):
return ( return (
[QUANT_OPS[self.quant_key]] [QUANT_OPS[self.activation_quant_key]]
if self.enable_quant_fp8_custom_op if self.enable_quant_fp8_custom_op
else [torch.ops.aten.reciprocal] else [torch.ops.aten.reciprocal]
) )

View File

@ -29,11 +29,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel, init_fp8_linear_kernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 from vllm.model_executor.layers.quantization.utils.quant_utils import (
ScaledMMLinearQuantStrategy, kFp8StaticTensorSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
GroupShape,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
@ -80,6 +77,8 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size=16, token_num=16, eps=1e-6): def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -95,9 +94,8 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
] ]
self.fp8_linear = init_fp8_linear_kernel( self.fp8_linear = init_fp8_linear_kernel(
act_q_static=True, activation_quant_key=self.quant_key,
act_q_group_shape=GroupShape.PER_TENSOR, weight_quant_key=self.quant_key,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
out_dtype=torch.get_default_dtype(), out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__, module_name=self.__class__.__name__,
) )

View File

@ -31,9 +31,6 @@ from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel, init_fp8_linear_kernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kFp8StaticTensorSym, kFp8StaticTensorSym,
@ -177,18 +174,13 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if self.quant_key.scale.group_shape.is_per_tensor():
weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR
else:
weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL
self.fp8_linear = init_fp8_linear_kernel( self.fp8_linear = init_fp8_linear_kernel(
act_q_static=self.quant_key.scale.static, activation_quant_key=self.quant_key,
act_q_group_shape=self.quant_key.scale.group_shape, weight_quant_key=self.quant_key,
weight_quant_strategy=weight_quant_strategy,
out_dtype=torch.get_default_dtype(), out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__, module_name=self.__class__.__name__,
) )
hidden_size = self.num_qo_heads * self.head_size hidden_size = self.num_qo_heads * self.head_size
self.w = kwargs.get( self.w = kwargs.get(
"w", "w",

View File

@ -30,10 +30,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel, init_fp8_linear_kernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 from vllm.model_executor.layers.quantization.utils.quant_utils import (
ScaledMMLinearQuantStrategy, kFp8StaticTensorSym,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
@ -101,6 +100,8 @@ class TestModel(torch.nn.Module):
class TestQuantModel(torch.nn.Module): class TestQuantModel(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size=16, intermediate_size=32): def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -114,9 +115,8 @@ class TestQuantModel(torch.nn.Module):
torch.nn.init.normal_(self.gate_proj, std=0.02) torch.nn.init.normal_(self.gate_proj, std=0.02)
self.fp8_linear = init_fp8_linear_kernel( self.fp8_linear = init_fp8_linear_kernel(
act_q_static=True, activation_quant_key=self.quant_key,
act_q_group_shape=GroupShape.PER_TENSOR, weight_quant_key=self.quant_key,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
out_dtype=torch.get_default_dtype(), out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__, module_name=self.__class__.__name__,
) )

View File

@ -27,11 +27,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel, init_fp8_linear_kernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8StaticTensorSym, kFp8StaticTensorSym,
kNvfp4Quant, kNvfp4Quant,
) )
@ -52,6 +48,8 @@ def is_nvfp4_supported():
class TestSiluMulFp8QuantModel(torch.nn.Module): class TestSiluMulFp8QuantModel(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs): def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
super().__init__() super().__init__()
self.silu_and_mul = SiluAndMul() self.silu_and_mul = SiluAndMul()
@ -62,13 +60,11 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
with override_cutlass_fp8_supported(not cuda_force_torch): with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear = init_fp8_linear_kernel( self.fp8_linear = init_fp8_linear_kernel(
act_q_static=True, activation_quant_key=self.quant_key,
act_q_group_shape=GroupShape.PER_TENSOR, weight_quant_key=self.quant_key,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
out_dtype=torch.get_default_dtype(), out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__, module_name=self.__class__.__name__,
) )
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()

View File

@ -14,9 +14,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel, init_fp8_linear_kernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
QUANT_STRATEGY_MAP,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, W8A8BlockFp8LinearOp,
check_aiter_fp8_linear_support, check_aiter_fp8_linear_support,
@ -29,7 +26,11 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_weight_tensor_strategy, process_fp8_weight_tensor_strategy,
validate_fp8_block_shape, validate_fp8_block_shape,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported, cutlass_block_fp8_supported,
maybe_create_device_identity, maybe_create_device_identity,
@ -48,6 +49,12 @@ strategy_to_parameter_type = {
QuantizationStrategy.TENSOR: PerTensorScaleParameter, QuantizationStrategy.TENSOR: PerTensorScaleParameter,
} }
STATIC_QUANT = True
DYNAMIC_QUANT = False
quant_keys = {
STATIC_QUANT: (kFp8StaticTensorSym, kFp8StaticTensorSym),
DYNAMIC_QUANT: (kFp8DynamicTokenSym, kFp8StaticTensorSym),
}
logger = init_logger(__name__) logger = init_logger(__name__)
@ -57,22 +64,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.strategy = weight_quant.strategy self.strategy = weight_quant.strategy
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
self.is_static_input_scheme = is_static_input_scheme self.is_static_input_scheme = is_static_input_scheme
self.weight_block_size = self.weight_quant.block_structure self.weight_block_size = self.weight_quant.block_structure
if self.weight_block_size is not None:
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
else:
self.act_q_group_shape = (
GroupShape.PER_TENSOR
if is_static_input_scheme
else GroupShape.PER_TOKEN
)
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
if self.weight_block_size is not None: if self.weight_block_size is not None:
assert not self.is_static_input_scheme assert not self.is_static_input_scheme
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size), weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=self.act_q_group_shape, act_quant_group_shape=self.act_q_group_shape,
@ -80,12 +78,11 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
use_aiter_and_is_supported=self.use_aiter_and_is_supported, use_aiter_and_is_supported=self.use_aiter_and_is_supported,
) )
else: else:
weight_quant_strategy = QUANT_STRATEGY_MAP[self.strategy] activation_quant_key, weight_quant_key = quant_keys[is_static_input_scheme]
self.fp8_linear_kernel = init_fp8_linear_kernel( self.fp8_linear = init_fp8_linear_kernel(
act_q_static=self.is_static_input_scheme, activation_quant_key=activation_quant_key,
act_q_group_shape=self.act_q_group_shape, weight_quant_key=weight_quant_key,
weight_quant_strategy=weight_quant_strategy, out_dtype=torch.get_default_dtype(),
out_dtype=self.out_dtype,
module_name=self.__class__.__name__, module_name=self.__class__.__name__,
) )
@ -204,4 +201,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
bias=bias, bias=bias,
) )
return self.fp8_linear_kernel.apply_weights(layer, x, bias) return self.fp8_linear.apply_weights(layer, x, bias)

View File

@ -21,16 +21,13 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel, init_fp8_linear_kernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin, prepare_fp8_layer_for_marlin,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
is_layer_skipped, is_layer_skipped,
kFp8DynamicTokenSym,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
maybe_create_device_identity, maybe_create_device_identity,
@ -97,12 +94,10 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: FBGEMMFp8Config): def __init__(self, quant_config: FBGEMMFp8Config):
self.quant_config = quant_config self.quant_config = quant_config
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
self.fp8_linear = init_fp8_linear_kernel(
self.fp8_linear_kernel = init_fp8_linear_kernel( activation_quant_key=kFp8DynamicTokenSym,
act_q_static=False, weight_quant_key=kFp8DynamicTokenSym,
act_q_group_shape=GroupShape.PER_TOKEN, out_dtype=torch.get_default_dtype(),
weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__, module_name=self.__class__.__name__,
) )
@ -194,4 +189,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
bias=bias, bias=bias,
) )
return self.fp8_linear_kernel.apply_weights(layer, x, bias) return self.fp8_linear.apply_weights(layer, x, bias)

View File

@ -45,10 +45,6 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel, init_fp8_linear_kernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa E501
FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend, FlashinferMoeBackend,
@ -82,6 +78,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
is_layer_skipped, is_layer_skipped,
kFp8DynamicTensorSym,
kFp8StaticTensorSym,
kFp8StaticTokenSym,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, all_close_1d,
@ -380,8 +379,10 @@ class Fp8LinearMethod(LinearMethodBase):
# Use per-token quantization for better perf if dynamic and cutlass # Use per-token quantization for better perf if dynamic and cutlass
if not self.act_q_static and cutlass_fp8_supported(): if not self.act_q_static and cutlass_fp8_supported():
self.act_q_group_shape = GroupShape.PER_TOKEN self.act_q_group_shape = GroupShape.PER_TOKEN
self.activation_quant_key = kFp8StaticTokenSym
else: else:
self.act_q_group_shape = GroupShape.PER_TENSOR self.act_q_group_shape = GroupShape.PER_TENSOR
self.activation_quant_key = kFp8DynamicTensorSym
if self.block_quant: if self.block_quant:
assert not self.act_q_static assert not self.act_q_static
@ -393,11 +394,10 @@ class Fp8LinearMethod(LinearMethodBase):
use_aiter_and_is_supported=self.use_aiter_and_is_supported, use_aiter_and_is_supported=self.use_aiter_and_is_supported,
) )
else: else:
self.fp8_linear_kernel = init_fp8_linear_kernel( self.fp8_linear = init_fp8_linear_kernel(
act_q_static=self.act_q_static, activation_quant_key=self.activation_quant_key,
act_q_group_shape=self.act_q_group_shape, weight_quant_key=kFp8StaticTensorSym,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, out_dtype=torch.get_default_dtype(),
out_dtype=self.out_dtype,
module_name=self.__class__.__name__, module_name=self.__class__.__name__,
) )
@ -684,7 +684,7 @@ class Fp8LinearMethod(LinearMethodBase):
bias=bias, bias=bias,
) )
return self.fp8_linear_kernel.apply_weights(layer, x, bias) return self.fp8_linear.apply_weights(layer, x, bias)
class Fp8MoEMethod(FusedMoEMethodBase): class Fp8MoEMethod(FusedMoEMethodBase):

View File

@ -4,43 +4,32 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum
from typing import Generic, TypeVar from typing import Generic, TypeVar
import torch import torch
from compressed_tensors.quantization import QuantizationStrategy
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
class ScaledMMLinearQuantStrategy(Enum):
TENSOR = "tensor"
CHANNEL = "channel"
BLOCK = "block"
QUANT_STRATEGY_MAP = {
QuantizationStrategy.TENSOR: ScaledMMLinearQuantStrategy.TENSOR,
QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.CHANNEL,
}
@dataclass @dataclass
class ScaledMMLinearLayerConfig: class ScaledMMLinearLayerConfig:
is_static_input_scheme: bool pass
@dataclass @dataclass
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
is_static_input_scheme: bool
is_channelwise: bool is_channelwise: bool
input_symmetric: bool input_symmetric: bool
@dataclass @dataclass
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
weight_quant_strategy: ScaledMMLinearQuantStrategy weight_quant_key: QuantKey
activation_group_shape: GroupShape activation_quant_key: QuantKey
out_dtype: torch.dtype | None out_dtype: torch.dtype | None
@ -103,9 +92,10 @@ class FP8ScaledMMLinearKernel(
def __init__( def __init__(
self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str] self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
) -> None: ) -> None:
act_scale_descriptor = c.activation_quant_key.scale
self.quant_fp8 = QuantFP8( self.quant_fp8 = QuantFP8(
static=c.is_static_input_scheme, static=act_scale_descriptor.static,
group_shape=c.activation_group_shape, group_shape=act_scale_descriptor.group_shape,
num_token_padding=self.get_ouput_padding(), num_token_padding=self.get_ouput_padding(),
) )
super().__init__(c, layer_param_names) super().__init__(c, layer_param_names)

View File

@ -32,7 +32,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
Int8ScaledMMLinearLayerConfig, Int8ScaledMMLinearLayerConfig,
ScaledMMLinearKernel, ScaledMMLinearKernel,
ScaledMMLinearLayerConfig, ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
TritonScaledMMLinearKernel, TritonScaledMMLinearKernel,
@ -40,7 +39,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
XLAScaledMMLinearKernel, XLAScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms import PlatformEnum, current_platform from vllm.platforms import PlatformEnum, current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
@ -137,16 +136,14 @@ def choose_scaled_mm_linear_kernel(
def init_fp8_linear_kernel( def init_fp8_linear_kernel(
act_q_static: bool, activation_quant_key: QuantKey,
act_q_group_shape: GroupShape, weight_quant_key: QuantKey,
weight_quant_strategy: ScaledMMLinearQuantStrategy,
out_dtype: torch.dtype, out_dtype: torch.dtype,
module_name: str, module_name: str,
) -> FP8ScaledMMLinearKernel: ) -> FP8ScaledMMLinearKernel:
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
is_static_input_scheme=act_q_static, weight_quant_key=weight_quant_key,
weight_quant_strategy=weight_quant_strategy, activation_quant_key=activation_quant_key,
activation_group_shape=act_q_group_shape,
out_dtype=out_dtype, out_dtype=out_dtype,
) )

View File

@ -9,7 +9,6 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
from .ScaledMMLinearKernel import ( from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel, FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
) )
from .utils import apply_weights_fp8 from .utils import apply_weights_fp8
@ -39,10 +38,10 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
@classmethod @classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() per_tensor_activation_scales = (
per_tensor_weight_scales = ( c.activation_quant_key.scale.group_shape.is_per_tensor()
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
) )
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if not current_platform.is_cuda(): if not current_platform.is_cuda():
return ( return (

View File

@ -10,7 +10,6 @@ from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ( from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel, FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
) )
from .utils import apply_weights_fp8 from .utils import apply_weights_fp8
@ -143,10 +142,10 @@ class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel):
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@classmethod @classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() per_tensor_activation_scales = (
per_tensor_weight_scales = ( c.activation_quant_key.scale.group_shape.is_per_tensor()
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
) )
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if not (per_tensor_activation_scales and per_tensor_weight_scales): if not (per_tensor_activation_scales and per_tensor_weight_scales):
return ( return (
@ -183,10 +182,10 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@classmethod @classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() per_tensor_activation_scales = (
per_tensor_weight_scales = ( c.activation_quant_key.scale.group_shape.is_per_tensor()
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
) )
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if per_tensor_activation_scales or per_tensor_weight_scales: if per_tensor_activation_scales or per_tensor_weight_scales:
return ( return (
@ -237,10 +236,10 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@classmethod @classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() per_tensor_activation_scales = (
per_tensor_weight_scales = ( c.activation_quant_key.scale.group_shape.is_per_tensor()
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
) )
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if per_tensor_activation_scales and per_tensor_weight_scales: if per_tensor_activation_scales and per_tensor_weight_scales:
return ( return (

View File

@ -11,7 +11,6 @@ from vllm.utils.torch_utils import direct_register_custom_op
from .ScaledMMLinearKernel import ( from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel, FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
) )
from .utils import apply_weights_fp8 from .utils import apply_weights_fp8
@ -72,18 +71,18 @@ def rocm_per_tensor_float_w8a8_scaled_mm(
bias: torch.Tensor, bias: torch.Tensor,
output_shape: list[int], output_shape: list[int],
) -> torch.Tensor: ) -> torch.Tensor:
output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl( output = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl(
A, B, out_dtype, As, Bs, bias A, B, out_dtype, As, Bs, bias
) )
return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape) return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)
if current_platform.is_rocm(): # if current_platform.is_rocm():
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl", op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl, op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake, fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
) )
class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel): class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
@ -95,10 +94,10 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
# TODO: check if this causes an issue on non-ROCM platforms # TODO: check if this causes an issue on non-ROCM platforms
from vllm.platforms.rocm import on_mi3xx from vllm.platforms.rocm import on_mi3xx
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() per_tensor_activation_scales = (
per_tensor_weight_scales = ( c.activation_quant_key.scale.group_shape.is_per_tensor()
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
) )
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if not current_platform.is_rocm(): if not current_platform.is_rocm():
return ( return (

View File

@ -40,9 +40,6 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel, init_fp8_linear_kernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize, build_flashinfer_fp4_cutlass_moe_prepare_finalize,
@ -68,9 +65,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
cutlass_fp4_supported, cutlass_fp4_supported,
is_layer_skipped, is_layer_skipped,
kFp8StaticTensorSym,
swizzle_blockscale, swizzle_blockscale,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@ -260,10 +257,9 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptFp8Config) -> None: def __init__(self, quant_config: ModelOptFp8Config) -> None:
self.quant_config = quant_config self.quant_config = quant_config
self.fp8_linear = init_fp8_linear_kernel( self.fp8_linear = init_fp8_linear_kernel(
act_q_static=True, activation_quant_key=kFp8StaticTensorSym,
act_q_group_shape=GroupShape.PER_TENSOR, weight_quant_key=kFp8StaticTensorSym,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, out_dtype=torch.get_default_dtype(),
out_dtype=None,
module_name=self.__class__.__name__, module_name=self.__class__.__name__,
) )

View File

@ -19,12 +19,9 @@ from vllm.model_executor.layers.quantization.fp8 import (
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel, init_fp8_linear_kernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
is_layer_skipped, is_layer_skipped,
kFp8DynamicTokenSym,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -103,11 +100,10 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
) )
super().__init__(quant_config=quant_config) super().__init__(quant_config=quant_config)
# Force weight quantization # Force weight quantization
self.fp8_linear_kernel = init_fp8_linear_kernel( self.fp8_linear = init_fp8_linear_kernel(
act_q_static=False, activation_quant_key=kFp8DynamicTokenSym,
act_q_group_shape=GroupShape.PER_TOKEN, weight_quant_key=kFp8DynamicTokenSym,
weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL, out_dtype=torch.get_default_dtype(),
out_dtype=self.out_dtype,
module_name=self.__class__.__name__, module_name=self.__class__.__name__,
) )
@ -135,4 +131,4 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.fp8_linear_kernel.apply_weights(layer, x, bias) return self.fp8_linear.apply_weights(layer, x, bias)

View File

@ -11,11 +11,13 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel, init_fp8_linear_kernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kFp8StaticTokenSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
normalize_e4m3fn_to_e4m3fnuz, normalize_e4m3fn_to_e4m3fnuz,
requantize_with_max_scale, requantize_with_max_scale,
@ -31,11 +33,6 @@ __all__ = ["QuarkW8A8Fp8"]
logger = init_logger(__name__) logger = init_logger(__name__)
QUANT_STRATEGY_MAP = {
"per_tensor": ScaledMMLinearQuantStrategy.TENSOR,
"per_channel": ScaledMMLinearQuantStrategy.CHANNEL,
}
class QuarkW8A8Fp8(QuarkScheme): class QuarkW8A8Fp8(QuarkScheme):
def __init__( def __init__(
@ -48,11 +45,16 @@ class QuarkW8A8Fp8(QuarkScheme):
self.is_static_input_scheme = not cast(bool, input_config.get("is_dynamic")) self.is_static_input_scheme = not cast(bool, input_config.get("is_dynamic"))
self.input_qscheme = cast(str, input_config.get("qscheme")) self.input_qscheme = cast(str, input_config.get("qscheme"))
per_token = ( per_token_activation = (
not self.is_static_input_scheme and self.input_qscheme == "per_channel" not self.is_static_input_scheme and self.input_qscheme == "per_channel"
) )
self.act_quant_group_shape = ( per_token_weight = self.weight_qscheme == "per_channel"
GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR
self.activation_quant_key = (
kFp8DynamicTokenSym if per_token_activation else kFp8StaticTensorSym
)
self.weight_quant_key = (
kFp8StaticTokenSym if per_token_weight else kFp8StaticTensorSym
) )
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
@ -103,7 +105,7 @@ class QuarkW8A8Fp8(QuarkScheme):
layer.input_scale = Parameter(input_scale, requires_grad=False) layer.input_scale = Parameter(input_scale, requires_grad=False)
else: else:
weight_scale = layer.weight_scale.data weight_scale = layer.weight_scale.data
if self.act_quant_group_shape == GroupShape.PER_TOKEN: if self.activation_quant_key.scale.group_shape == GroupShape.PER_TOKEN:
weight_scale = weight_scale.view(-1, 1) weight_scale = weight_scale.view(-1, 1)
layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter # required by torch.compile to be torch.nn.Parameter
@ -174,12 +176,10 @@ class QuarkW8A8Fp8(QuarkScheme):
layer.register_parameter("input_scale_ub", None) layer.register_parameter("input_scale_ub", None)
weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme] self.fp8_linear = init_fp8_linear_kernel(
self.fp8_linear_kernel = init_fp8_linear_kernel( activation_quant_key=self.activation_quant_key,
act_q_static=self.is_static_input_scheme, weight_quant_key=self.weight_quant_key,
act_q_group_shape=self.act_quant_group_shape, out_dtype=torch.get_default_dtype(),
weight_quant_strategy=weight_quant_strategy,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__, module_name=self.__class__.__name__,
) )
@ -189,4 +189,4 @@ class QuarkW8A8Fp8(QuarkScheme):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.fp8_linear_kernel.apply_weights(layer, x, bias) return self.fp8_linear.apply_weights(layer, x, bias)

View File

@ -109,6 +109,9 @@ kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True)
kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR) kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True) kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True)
kStaticTokenScale = ScaleDesc(torch.float32, True, GroupShape.PER_TOKEN)
kFp8StaticTokenSym = QuantKey(FP8_DTYPE, kStaticTokenScale, symmetric=True)
kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN) kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)