Prefer QuantKey over ScaledMMLinearQuantStrategy

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-11-04 12:11:13 +00:00
parent a8010c7b1c
commit f5e6cd9695
18 changed files with 147 additions and 187 deletions

View File

@ -23,10 +23,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
@ -37,6 +36,8 @@ FP8_DTYPE = current_platform.fp8_dtype()
class TestSiluMul(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size: int = 128):
super().__init__()
self.silu_and_mul = SiluAndMul()
@ -46,9 +47,8 @@ class TestSiluMul(torch.nn.Module):
if TEST_FP8:
self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
self.fp8_linear = init_fp8_linear_kernel(
act_q_static=True,
act_q_group_shape=GroupShape.PER_TENSOR,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
@ -74,6 +74,8 @@ class TestSiluMul(torch.nn.Module):
class TestFusedAddRMSNorm(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__()
self.hidden_size = hidden_size
@ -89,9 +91,8 @@ class TestFusedAddRMSNorm(torch.nn.Module):
if TEST_FP8:
self.fp8_linear = init_fp8_linear_kernel(
act_q_static=True,
act_q_group_shape=GroupShape.PER_TENSOR,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)

View File

@ -21,9 +21,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
@ -59,10 +56,16 @@ class TestModel(torch.nn.Module):
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR
quant_scale = ScaleDesc(torch.float32, static, group_shape)
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
act_quant_scale = ScaleDesc(torch.float32, static, group_shape)
w_quant_scale = ScaleDesc(torch.float32, True, group_shape)
self.activation_quant_key = QuantKey(
dtype=FP8_DTYPE, scale=act_quant_scale, symmetric=True
)
self.weight_quant_key = QuantKey(
dtype=FP8_DTYPE, scale=w_quant_scale, symmetric=True
)
if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
else:
@ -74,9 +77,8 @@ class TestModel(torch.nn.Module):
with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear = init_fp8_linear_kernel(
act_q_static=static,
act_q_group_shape=group_shape,
weight_quant_strategy=weight_quant_strategy,
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
@ -110,13 +112,13 @@ class TestModel(torch.nn.Module):
def ops_in_model_after(self):
return [
FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, True)],
FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, False)],
]
def ops_in_model_before(self):
return (
[QUANT_OPS[self.quant_key]]
[QUANT_OPS[self.activation_quant_key]]
if self.enable_quant_fp8_custom_op
else [torch.ops.aten.reciprocal]
)

View File

@ -29,11 +29,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
GroupShape,
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables
@ -80,6 +77,8 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
@ -95,9 +94,8 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
]
self.fp8_linear = init_fp8_linear_kernel(
act_q_static=True,
act_q_group_shape=GroupShape.PER_TENSOR,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)

View File

@ -31,9 +31,6 @@ from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
@ -177,18 +174,13 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.quant_key.scale.group_shape.is_per_tensor():
weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR
else:
weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL
self.fp8_linear = init_fp8_linear_kernel(
act_q_static=self.quant_key.scale.static,
act_q_group_shape=self.quant_key.scale.group_shape,
weight_quant_strategy=weight_quant_strategy,
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
hidden_size = self.num_qo_heads * self.head_size
self.w = kwargs.get(
"w",

View File

@ -30,10 +30,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables
@ -101,6 +100,8 @@ class TestModel(torch.nn.Module):
class TestQuantModel(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__()
self.hidden_size = hidden_size
@ -114,9 +115,8 @@ class TestQuantModel(torch.nn.Module):
torch.nn.init.normal_(self.gate_proj, std=0.02)
self.fp8_linear = init_fp8_linear_kernel(
act_q_static=True,
act_q_group_shape=GroupShape.PER_TENSOR,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)

View File

@ -27,11 +27,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8StaticTensorSym,
kNvfp4Quant,
)
@ -52,6 +48,8 @@ def is_nvfp4_supported():
class TestSiluMulFp8QuantModel(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
super().__init__()
self.silu_and_mul = SiluAndMul()
@ -62,13 +60,11 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear = init_fp8_linear_kernel(
act_q_static=True,
act_q_group_shape=GroupShape.PER_TENSOR,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()

View File

@ -14,9 +14,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
QUANT_STRATEGY_MAP,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
check_aiter_fp8_linear_support,
@ -29,7 +26,11 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_weight_tensor_strategy,
validate_fp8_block_shape,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
maybe_create_device_identity,
@ -48,6 +49,12 @@ strategy_to_parameter_type = {
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
}
STATIC_QUANT = True
DYNAMIC_QUANT = False
quant_keys = {
STATIC_QUANT: (kFp8StaticTensorSym, kFp8StaticTensorSym),
DYNAMIC_QUANT: (kFp8DynamicTokenSym, kFp8StaticTensorSym),
}
logger = init_logger(__name__)
@ -57,22 +64,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.strategy = weight_quant.strategy
self.out_dtype = torch.get_default_dtype()
self.is_static_input_scheme = is_static_input_scheme
self.weight_block_size = self.weight_quant.block_structure
if self.weight_block_size is not None:
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
else:
self.act_q_group_shape = (
GroupShape.PER_TENSOR
if is_static_input_scheme
else GroupShape.PER_TOKEN
)
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
if self.weight_block_size is not None:
assert not self.is_static_input_scheme
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=self.act_q_group_shape,
@ -80,12 +78,11 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
)
else:
weight_quant_strategy = QUANT_STRATEGY_MAP[self.strategy]
self.fp8_linear_kernel = init_fp8_linear_kernel(
act_q_static=self.is_static_input_scheme,
act_q_group_shape=self.act_q_group_shape,
weight_quant_strategy=weight_quant_strategy,
out_dtype=self.out_dtype,
activation_quant_key, weight_quant_key = quant_keys[is_static_input_scheme]
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=activation_quant_key,
weight_quant_key=weight_quant_key,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
@ -204,4 +201,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
bias=bias,
)
return self.fp8_linear_kernel.apply_weights(layer, x, bias)
return self.fp8_linear.apply_weights(layer, x, bias)

View File

@ -21,16 +21,13 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
is_layer_skipped,
kFp8DynamicTokenSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
maybe_create_device_identity,
@ -97,12 +94,10 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: FBGEMMFp8Config):
self.quant_config = quant_config
self.out_dtype = torch.get_default_dtype()
self.fp8_linear_kernel = init_fp8_linear_kernel(
act_q_static=False,
act_q_group_shape=GroupShape.PER_TOKEN,
weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL,
out_dtype=self.out_dtype,
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=kFp8DynamicTokenSym,
weight_quant_key=kFp8DynamicTokenSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
@ -194,4 +189,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
bias=bias,
)
return self.fp8_linear_kernel.apply_weights(layer, x, bias)
return self.fp8_linear.apply_weights(layer, x, bias)

View File

@ -45,10 +45,6 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa E501
FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
@ -82,6 +78,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
is_layer_skipped,
kFp8DynamicTensorSym,
kFp8StaticTensorSym,
kFp8StaticTokenSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d,
@ -380,8 +379,10 @@ class Fp8LinearMethod(LinearMethodBase):
# Use per-token quantization for better perf if dynamic and cutlass
if not self.act_q_static and cutlass_fp8_supported():
self.act_q_group_shape = GroupShape.PER_TOKEN
self.activation_quant_key = kFp8StaticTokenSym
else:
self.act_q_group_shape = GroupShape.PER_TENSOR
self.activation_quant_key = kFp8DynamicTensorSym
if self.block_quant:
assert not self.act_q_static
@ -393,11 +394,10 @@ class Fp8LinearMethod(LinearMethodBase):
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
)
else:
self.fp8_linear_kernel = init_fp8_linear_kernel(
act_q_static=self.act_q_static,
act_q_group_shape=self.act_q_group_shape,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
out_dtype=self.out_dtype,
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key,
weight_quant_key=kFp8StaticTensorSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
@ -684,7 +684,7 @@ class Fp8LinearMethod(LinearMethodBase):
bias=bias,
)
return self.fp8_linear_kernel.apply_weights(layer, x, bias)
return self.fp8_linear.apply_weights(layer, x, bias)
class Fp8MoEMethod(FusedMoEMethodBase):

View File

@ -4,43 +4,32 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Generic, TypeVar
import torch
from compressed_tensors.quantization import QuantizationStrategy
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
class ScaledMMLinearQuantStrategy(Enum):
TENSOR = "tensor"
CHANNEL = "channel"
BLOCK = "block"
QUANT_STRATEGY_MAP = {
QuantizationStrategy.TENSOR: ScaledMMLinearQuantStrategy.TENSOR,
QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.CHANNEL,
}
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
@dataclass
class ScaledMMLinearLayerConfig:
is_static_input_scheme: bool
pass
@dataclass
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
is_static_input_scheme: bool
is_channelwise: bool
input_symmetric: bool
@dataclass
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
weight_quant_strategy: ScaledMMLinearQuantStrategy
activation_group_shape: GroupShape
weight_quant_key: QuantKey
activation_quant_key: QuantKey
out_dtype: torch.dtype | None
@ -103,9 +92,10 @@ class FP8ScaledMMLinearKernel(
def __init__(
self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
) -> None:
act_scale_descriptor = c.activation_quant_key.scale
self.quant_fp8 = QuantFP8(
static=c.is_static_input_scheme,
group_shape=c.activation_group_shape,
static=act_scale_descriptor.static,
group_shape=act_scale_descriptor.group_shape,
num_token_padding=self.get_ouput_padding(),
)
super().__init__(c, layer_param_names)

View File

@ -32,7 +32,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
Int8ScaledMMLinearLayerConfig,
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
TritonScaledMMLinearKernel,
@ -40,7 +39,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
XLAScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms import PlatformEnum, current_platform
logger = init_logger(__name__)
@ -137,16 +136,14 @@ def choose_scaled_mm_linear_kernel(
def init_fp8_linear_kernel(
act_q_static: bool,
act_q_group_shape: GroupShape,
weight_quant_strategy: ScaledMMLinearQuantStrategy,
activation_quant_key: QuantKey,
weight_quant_key: QuantKey,
out_dtype: torch.dtype,
module_name: str,
) -> FP8ScaledMMLinearKernel:
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
is_static_input_scheme=act_q_static,
weight_quant_strategy=weight_quant_strategy,
activation_group_shape=act_q_group_shape,
weight_quant_key=weight_quant_key,
activation_quant_key=activation_quant_key,
out_dtype=out_dtype,
)

View File

@ -9,7 +9,6 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
from .utils import apply_weights_fp8
@ -39,10 +38,10 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
per_tensor_weight_scales = (
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
per_tensor_activation_scales = (
c.activation_quant_key.scale.group_shape.is_per_tensor()
)
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if not current_platform.is_cuda():
return (

View File

@ -10,7 +10,6 @@ from vllm.platforms import current_platform
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
from .utils import apply_weights_fp8
@ -143,10 +142,10 @@ class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel):
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
per_tensor_weight_scales = (
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
per_tensor_activation_scales = (
c.activation_quant_key.scale.group_shape.is_per_tensor()
)
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if not (per_tensor_activation_scales and per_tensor_weight_scales):
return (
@ -183,10 +182,10 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
per_tensor_weight_scales = (
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
per_tensor_activation_scales = (
c.activation_quant_key.scale.group_shape.is_per_tensor()
)
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if per_tensor_activation_scales or per_tensor_weight_scales:
return (
@ -237,10 +236,10 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
per_tensor_weight_scales = (
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
per_tensor_activation_scales = (
c.activation_quant_key.scale.group_shape.is_per_tensor()
)
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if per_tensor_activation_scales and per_tensor_weight_scales:
return (

View File

@ -11,7 +11,6 @@ from vllm.utils.torch_utils import direct_register_custom_op
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
from .utils import apply_weights_fp8
@ -72,18 +71,18 @@ def rocm_per_tensor_float_w8a8_scaled_mm(
bias: torch.Tensor,
output_shape: list[int],
) -> torch.Tensor:
output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl(
output = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl(
A, B, out_dtype, As, Bs, bias
)
return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
)
# if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
)
class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
@ -95,10 +94,10 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
# TODO: check if this causes an issue on non-ROCM platforms
from vllm.platforms.rocm import on_mi3xx
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
per_tensor_weight_scales = (
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
per_tensor_activation_scales = (
c.activation_quant_key.scale.group_shape.is_per_tensor()
)
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if not current_platform.is_rocm():
return (

View File

@ -40,9 +40,6 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
@ -68,9 +65,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
cutlass_fp4_supported,
is_layer_skipped,
kFp8StaticTensorSym,
swizzle_blockscale,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@ -260,10 +257,9 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptFp8Config) -> None:
self.quant_config = quant_config
self.fp8_linear = init_fp8_linear_kernel(
act_q_static=True,
act_q_group_shape=GroupShape.PER_TENSOR,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
out_dtype=None,
activation_quant_key=kFp8StaticTensorSym,
weight_quant_key=kFp8StaticTensorSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)

View File

@ -19,12 +19,9 @@ from vllm.model_executor.layers.quantization.fp8 import (
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
is_layer_skipped,
kFp8DynamicTokenSym,
)
from vllm.platforms import current_platform
@ -103,11 +100,10 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
)
super().__init__(quant_config=quant_config)
# Force weight quantization
self.fp8_linear_kernel = init_fp8_linear_kernel(
act_q_static=False,
act_q_group_shape=GroupShape.PER_TOKEN,
weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL,
out_dtype=self.out_dtype,
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=kFp8DynamicTokenSym,
weight_quant_key=kFp8DynamicTokenSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
@ -135,4 +131,4 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return self.fp8_linear_kernel.apply_weights(layer, x, bias)
return self.fp8_linear.apply_weights(layer, x, bias)

View File

@ -11,11 +11,13 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kFp8StaticTokenSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
normalize_e4m3fn_to_e4m3fnuz,
requantize_with_max_scale,
@ -31,11 +33,6 @@ __all__ = ["QuarkW8A8Fp8"]
logger = init_logger(__name__)
QUANT_STRATEGY_MAP = {
"per_tensor": ScaledMMLinearQuantStrategy.TENSOR,
"per_channel": ScaledMMLinearQuantStrategy.CHANNEL,
}
class QuarkW8A8Fp8(QuarkScheme):
def __init__(
@ -48,11 +45,16 @@ class QuarkW8A8Fp8(QuarkScheme):
self.is_static_input_scheme = not cast(bool, input_config.get("is_dynamic"))
self.input_qscheme = cast(str, input_config.get("qscheme"))
per_token = (
per_token_activation = (
not self.is_static_input_scheme and self.input_qscheme == "per_channel"
)
self.act_quant_group_shape = (
GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR
per_token_weight = self.weight_qscheme == "per_channel"
self.activation_quant_key = (
kFp8DynamicTokenSym if per_token_activation else kFp8StaticTensorSym
)
self.weight_quant_key = (
kFp8StaticTokenSym if per_token_weight else kFp8StaticTensorSym
)
self.out_dtype = torch.get_default_dtype()
@ -103,7 +105,7 @@ class QuarkW8A8Fp8(QuarkScheme):
layer.input_scale = Parameter(input_scale, requires_grad=False)
else:
weight_scale = layer.weight_scale.data
if self.act_quant_group_shape == GroupShape.PER_TOKEN:
if self.activation_quant_key.scale.group_shape == GroupShape.PER_TOKEN:
weight_scale = weight_scale.view(-1, 1)
layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter
@ -174,12 +176,10 @@ class QuarkW8A8Fp8(QuarkScheme):
layer.register_parameter("input_scale_ub", None)
weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme]
self.fp8_linear_kernel = init_fp8_linear_kernel(
act_q_static=self.is_static_input_scheme,
act_q_group_shape=self.act_quant_group_shape,
weight_quant_strategy=weight_quant_strategy,
out_dtype=self.out_dtype,
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
@ -189,4 +189,4 @@ class QuarkW8A8Fp8(QuarkScheme):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return self.fp8_linear_kernel.apply_weights(layer, x, bias)
return self.fp8_linear.apply_weights(layer, x, bias)

View File

@ -109,6 +109,9 @@ kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True)
kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True)
kStaticTokenScale = ScaleDesc(torch.float32, True, GroupShape.PER_TOKEN)
kFp8StaticTokenSym = QuantKey(FP8_DTYPE, kStaticTokenScale, symmetric=True)
kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)