fix quant key selection for ct; remove register_paramter calls; format

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-11-04 14:12:14 +00:00
parent 93fb7071f5
commit abf597e542
11 changed files with 114 additions and 60 deletions

View File

@ -25,8 +25,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
) )
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..utils import TestFP8Layer
from ..utils import TestFP8Layer
from .backend import TestBackend from .backend import TestBackend
TEST_FP8 = current_platform.supports_fp8() TEST_FP8 = current_platform.supports_fp8()
@ -43,8 +43,14 @@ class TestSiluMul(torch.nn.Module):
self.input_scale = torch.rand(1, dtype=torch.float32) self.input_scale = torch.rand(1, dtype=torch.float32)
if TEST_FP8: if TEST_FP8:
self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, self.weight, self.fp8_linear = TestFP8Layer(
self.weight_scale, self.input_scale) self.quant_key,
self.quant_key,
self.weight,
self.weight_scale,
self.input_scale,
)
def forward(self, x): def forward(self, x):
y = self.silu_and_mul(x) y = self.silu_and_mul(x)
if TEST_FP8: if TEST_FP8:
@ -87,8 +93,13 @@ class TestFusedAddRMSNorm(torch.nn.Module):
) )
self.weight_scale = torch.rand(1, dtype=torch.float32) self.weight_scale = torch.rand(1, dtype=torch.float32)
self.input_scale = torch.rand(1, dtype=torch.float32) self.input_scale = torch.rand(1, dtype=torch.float32)
self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, self.fp8_linear = TestFP8Layer(
self.weight, self.weight_scale, self.input_scale) self.quant_key,
self.quant_key,
self.weight,
self.weight_scale,
self.input_scale,
)
def forward(self, hidden_states, residual): def forward(self, hidden_states, residual):
# Reshape input # Reshape input

View File

@ -18,7 +18,6 @@ from vllm.config import (
VllmConfig, VllmConfig,
) )
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
QuantKey, QuantKey,
@ -74,12 +73,27 @@ class TestModel(torch.nn.Module):
] ]
with override_cutlass_fp8_supported(not cuda_force_torch): with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear_1 = TestFP8Layer(self.activation_quant_key, self.weight_quant_key, self.fp8_linear_1 = TestFP8Layer(
self.w[0], self.wscale[0], self.scale[0]) self.activation_quant_key,
self.fp8_linear_2 = TestFP8Layer(self.activation_quant_key, self.weight_quant_key, self.weight_quant_key,
self.w[1], self.wscale[1], self.scale[1]) self.w[0],
self.fp8_linear_3 = TestFP8Layer(self.activation_quant_key, self.weight_quant_key, self.wscale[0],
self.w[2], self.wscale[2], self.scale[2]) self.scale[0],
)
self.fp8_linear_2 = TestFP8Layer(
self.activation_quant_key,
self.weight_quant_key,
self.w[1],
self.wscale[1],
self.scale[1],
)
self.fp8_linear_3 = TestFP8Layer(
self.activation_quant_key,
self.weight_quant_key,
self.w[2],
self.wscale[2],
self.scale[2],
)
self.enable_rms_norm_custom_op = self.norm[0].enabled() self.enable_rms_norm_custom_op = self.norm[0].enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled() self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled()

View File

@ -26,7 +26,6 @@ from vllm.distributed.parallel_state import (
initialize_model_parallel, initialize_model_parallel,
) )
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym, kFp8StaticTensorSym,
) )
@ -91,14 +90,29 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
for _ in range(3) for _ in range(3)
] ]
self.fp8_linear_1 = TestFP8Layer(self.quant_key,self.quant_key, self.fp8_linear_1 = TestFP8Layer(
self.weight[0],self.wscale[0], input_scale=self.input_scale[0]) self.quant_key,
self.quant_key,
self.fp8_linear_2 = TestFP8Layer(self.quant_key,self.quant_key, self.weight[0],
self.weight[1],self.wscale[1], input_scale=self.input_scale[1]) self.wscale[0],
input_scale=self.input_scale[0],
self.fp8_linear_3 = TestFP8Layer(self.quant_key, self.quant_key, )
self.weight[2], self.wscale[2],input_scale=self.input_scale[2])
self.fp8_linear_2 = TestFP8Layer(
self.quant_key,
self.quant_key,
self.weight[1],
self.wscale[1],
input_scale=self.input_scale[1],
)
self.fp8_linear_3 = TestFP8Layer(
self.quant_key,
self.quant_key,
self.weight[2],
self.wscale[2],
input_scale=self.input_scale[2],
)
def forward(self, hidden_states): def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly # avoid having graph input be an arg to a pattern directly
@ -106,7 +120,6 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
x = resid = tensor_model_parallel_all_reduce(z) x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x) y = self.norm[0](x)
z2 = self.fp8_linear_1(y) z2 = self.fp8_linear_1(y)
x2 = tensor_model_parallel_all_reduce(z2) x2 = tensor_model_parallel_all_reduce(z2)

View File

@ -28,7 +28,6 @@ from vllm.config import (
set_current_vllm_config, set_current_vllm_config,
) )
from vllm.forward_context import get_forward_context, set_forward_context from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kFp8StaticTensorSym, kFp8StaticTensorSym,
@ -172,7 +171,6 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
hidden_size = self.num_qo_heads * self.head_size hidden_size = self.num_qo_heads * self.head_size
self.w = kwargs.get( self.w = kwargs.get(
"w", "w",
@ -184,8 +182,13 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
"scale": torch.tensor([1.0], dtype=torch.float32, device=self.device), "scale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
}, },
) )
self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, self.w["weight"], self.fp8_linear = TestFP8Layer(
self.w["wscale"], self.w["scale"]) self.quant_key,
self.quant_key,
self.w["weight"],
self.w["wscale"],
self.w["scale"],
)
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Forward pass that creates the pattern to be fused.""" """Forward pass that creates the pattern to be fused."""

View File

@ -27,7 +27,6 @@ from vllm.distributed.parallel_state import (
initialize_model_parallel, initialize_model_parallel,
) )
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym, kFp8StaticTensorSym,
) )
@ -117,10 +116,10 @@ class TestQuantModel(torch.nn.Module):
# which expects a column-major layout. # which expects a column-major layout.
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
self.wscale = torch.rand(1, dtype=torch.float32) self.wscale = torch.rand(1, dtype=torch.float32)
self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, self.fp8_linear = TestFP8Layer(
self.w, self.wscale, self.scale) self.quant_key, self.quant_key, self.w, self.wscale, self.scale
)
def forward(self, hidden_states, residual): def forward(self, hidden_states, residual):
""" """
Forward pass implementing the operations in the FX graph Forward pass implementing the operations in the FX graph

View File

@ -24,7 +24,6 @@ from vllm.config import (
set_current_vllm_config, set_current_vllm_config,
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym, kFp8StaticTensorSym,
kNvfp4Quant, kNvfp4Quant,
@ -56,10 +55,14 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
with override_cutlass_fp8_supported(not cuda_force_torch): with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, self.fp8_linear = TestFP8Layer(
self.weight, self.weight_scale, self.input_scale) self.quant_key,
self.quant_key,
self.weight,
self.weight_scale,
self.input_scale,
)
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled() self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled()

View File

@ -42,6 +42,10 @@ from vllm.distributed import (
) )
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.cli.serve import ServeSubcommand from vllm.entrypoints.cli.serve import ServeSubcommand
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader import get_model_loader
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
@ -49,8 +53,6 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.mem_constants import GB_bytes from vllm.utils.mem_constants import GB_bytes
from vllm.utils.network_utils import get_open_port from vllm.utils.network_utils import get_open_port
from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.model_executor.layers.quantization.kernels.scaled_mm import init_fp8_linear_kernel
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
if current_platform.is_rocm(): if current_platform.is_rocm():
from amdsmi import ( from amdsmi import (
@ -1429,32 +1431,36 @@ class TestFP8Layer(torch.nn.Module):
weight (torch.Tensor): Weight tensor for linear transformation. weight (torch.Tensor): Weight tensor for linear transformation.
weight_scale (torch.Tensor): Per-tensor or per-group scale for weights. weight_scale (torch.Tensor): Per-tensor or per-group scale for weights.
input_scale (torch.Tensor): Scale tensor for input quantization. input_scale (torch.Tensor): Scale tensor for input quantization.
out_dtype (torch.dtype, optional): Output tensor data type. Defaults to torch.get_default_dtype(). out_dtype (torch.dtype, optional): Output tensor data type.
Defaults to torch.get_default_dtype().
""" """
def __init__(self,
activation_quant_key: QuantKey, def __init__(
weight_quant_key: QuantKey, self,
weight:torch.Tensor, activation_quant_key: QuantKey,
weight_scale:torch.Tensor, weight_quant_key: QuantKey,
input_scale:torch.Tensor, weight: torch.Tensor,
out_dtype: torch.dtype = torch.get_default_dtype() weight_scale: torch.Tensor,
): input_scale: torch.Tensor,
out_dtype: torch.dtype | None = None,
):
super().__init__() super().__init__()
self.weight_scale = weight_scale self.weight_scale = weight_scale
self.weight = weight self.weight = weight
self.input_scale = input_scale self.input_scale = input_scale
self.input_scale_ub = None self.input_scale_ub = None
out_dtype = torch.get_default_dtype() if out_dtype is None else out_dtype
self.kernel = init_fp8_linear_kernel( self.kernel = init_fp8_linear_kernel(
activation_quant_key=activation_quant_key, activation_quant_key=activation_quant_key,
weight_quant_key=weight_quant_key, weight_quant_key=weight_quant_key,
out_dtype=out_dtype, out_dtype=out_dtype,
module_name=self.__class__.__name__, module_name=self.__class__.__name__,
) )
def is_quant_fp8_enabled(self) -> bool: def is_quant_fp8_enabled(self) -> bool:
return self.kernel.quant_fp8.enabled() return self.kernel.quant_fp8.enabled()
def forward(self, y: torch.Tensor, bias: torch.Tensor | None=None) -> torch.Tensor: def forward(
self, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
return self.kernel.apply_weights(self, y, bias) return self.kernel.apply_weights(self, y, bias)

View File

@ -30,6 +30,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
kFp8DynamicTokenSym, kFp8DynamicTokenSym,
kFp8StaticTensorSym, kFp8StaticTensorSym,
kFp8StaticTokenSym,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported, cutlass_block_fp8_supported,
@ -51,9 +52,13 @@ strategy_to_parameter_type = {
STATIC_QUANT = True STATIC_QUANT = True
DYNAMIC_QUANT = False DYNAMIC_QUANT = False
quant_keys = { activation_quant_key_mapping = {
STATIC_QUANT: (kFp8StaticTensorSym, kFp8StaticTensorSym), STATIC_QUANT: kFp8StaticTensorSym,
DYNAMIC_QUANT: (kFp8DynamicTokenSym, kFp8StaticTensorSym), DYNAMIC_QUANT: kFp8DynamicTokenSym,
}
weight_quant_key_mapping = {
QuantizationStrategy.CHANNEL: kFp8StaticTokenSym,
QuantizationStrategy.TENSOR: kFp8StaticTensorSym,
} }
logger = init_logger(__name__) logger = init_logger(__name__)
@ -78,7 +83,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
use_aiter_and_is_supported=self.use_aiter_and_is_supported, use_aiter_and_is_supported=self.use_aiter_and_is_supported,
) )
else: else:
activation_quant_key, weight_quant_key = quant_keys[is_static_input_scheme] activation_quant_key = activation_quant_key_mapping[is_static_input_scheme]
weight_quant_key = weight_quant_key_mapping[self.strategy]
self.fp8_linear = init_fp8_linear_kernel( self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=activation_quant_key, activation_quant_key=activation_quant_key,
weight_quant_key=weight_quant_key, weight_quant_key=weight_quant_key,
@ -143,7 +149,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader) input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
layer.register_parameter("input_scale", input_scale) layer.register_parameter("input_scale", input_scale)
layer.register_parameter("input_scale_ub", None) layer.input_scale_ub = None
def process_weights_after_loading(self, layer) -> None: def process_weights_after_loading(self, layer) -> None:
if self.strategy == QuantizationStrategy.TENSOR: if self.strategy == QuantizationStrategy.TENSOR:

View File

@ -451,7 +451,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_loader=weight_loader, weight_loader=weight_loader,
) )
layer.register_parameter("weight", weight) layer.register_parameter("weight", weight)
layer.register_parameter("input_scale_ub", None) layer.input_scale_ub = None
# If checkpoint is serialized fp8, load them. # If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading. # Otherwise, wait until process_weights_after_loading.

View File

@ -17,9 +17,8 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
CutlassFP8ScaledMMLinearKernel, CutlassFP8ScaledMMLinearKernel,
CutlassScaledMMLinearKernel, CutlassScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
FlashInferScaledMMLinearKernel FlashInferScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
ChannelWiseTorchScaledMMLinearKernel, ChannelWiseTorchScaledMMLinearKernel,
@ -64,7 +63,7 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
PerTensorTorchScaledMMLinearKernel, PerTensorTorchScaledMMLinearKernel,
RowWiseTorchScaledMMLinearKernel, RowWiseTorchScaledMMLinearKernel,
ChannelWiseTorchScaledMMLinearKernel, ChannelWiseTorchScaledMMLinearKernel,
], ],
PlatformEnum.ROCM: [ PlatformEnum.ROCM: [
ROCmScaledMMLinearKernel, ROCmScaledMMLinearKernel,
PerTensorTorchScaledMMLinearKernel, PerTensorTorchScaledMMLinearKernel,
@ -164,7 +163,7 @@ def init_fp8_linear_kernel(
logger.info_once( logger.info_once(
"Selected %s for %s", "Selected %s for %s",
kernel_type.__class__.__name__, kernel_type.__name__,
module_name, module_name,
scope="global", scope="global",
) )

View File

@ -174,7 +174,7 @@ class QuarkW8A8Fp8(QuarkScheme):
input_scale[:] = torch.finfo(torch.float32).min input_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", input_scale) layer.register_parameter("input_scale", input_scale)
layer.register_parameter("input_scale_ub", None) layer.input_scale_ub = None
self.fp8_linear = init_fp8_linear_kernel( self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key, activation_quant_key=self.activation_quant_key,