update unit tests to use ScaledMMLinearKernels

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-11-01 16:28:03 +00:00
parent 4ce0ba2df4
commit dd5a70ec71
12 changed files with 152 additions and 89 deletions

View File

@ -20,8 +20,13 @@ from vllm.config import (
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
@ -35,21 +40,23 @@ class TestSiluMul(torch.nn.Module):
def __init__(self, hidden_size: int = 128):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.wscale = torch.rand(1, dtype=torch.float32)
self.scale = torch.rand(1, dtype=torch.float32)
self.weight_scale = torch.rand(1, dtype=torch.float32)
self.input_scale = torch.rand(1, dtype=torch.float32)
self.input_scale_ub = None
if TEST_FP8:
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
self.fp8_linear = init_fp8_linear_kernel(
act_q_static=True,
act_q_group_shape=GroupShape.PER_TENSOR,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
def forward(self, x):
y = self.silu_and_mul(x)
if TEST_FP8:
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
return x2
return self.fp8_linear.apply_weights(self, y)
else:
return y
@ -81,11 +88,19 @@ class TestFusedAddRMSNorm(torch.nn.Module):
torch.nn.init.normal_(self.gate_proj, std=0.02)
if TEST_FP8:
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
self.scale = torch.rand(1, dtype=torch.float32)
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
self.wscale = torch.rand(1, dtype=torch.float32)
self.fp8_linear = init_fp8_linear_kernel(
act_q_static=True,
act_q_group_shape=GroupShape.PER_TENSOR,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
self.weight = (
torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
)
self.weight_scale = torch.rand(1, dtype=torch.float32)
self.input_scale = torch.rand(1, dtype=torch.float32)
self.input_scale_ub = None
def forward(self, hidden_states, residual):
# Reshape input
@ -99,13 +114,9 @@ class TestFusedAddRMSNorm(torch.nn.Module):
norm_output, residual_output = self.norm(mm, residual)
if TEST_FP8:
self.input_scale = self.input_scale.to(norm_output.device)
# scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply(
norm_output,
self.w,
self.wscale,
input_scale=self.scale.to(norm_output.device),
)
fp8_linear_result = self.fp8_linear.apply_weights(self, norm_output)
return fp8_linear_result, residual_output

View File

@ -18,19 +18,24 @@ from vllm.config import (
VllmConfig,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
ScaleDesc,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
cutlass_fp8_supported,
maybe_create_device_identity,
)
from vllm.platforms import current_platform
from ..utils import override_cutlass_fp8_supported
from ..utils import TestFP8Layer, override_cutlass_fp8_supported
from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype()
@ -54,6 +59,8 @@ class TestModel(torch.nn.Module):
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR
quant_scale = ScaleDesc(torch.float32, static, group_shape)
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
if static:
@ -66,9 +73,12 @@ class TestModel(torch.nn.Module):
]
with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear = Fp8LinearOp(
act_quant_static=static,
act_quant_group_shape=group_shape,
self.fp8_linear = init_fp8_linear_kernel(
act_q_static=static,
act_q_group_shape=group_shape,
weight_quant_strategy=weight_quant_strategy,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
self.enable_rms_norm_custom_op = self.norm[0].enabled()
@ -79,20 +89,20 @@ class TestModel(torch.nn.Module):
x = resid = torch.relu(x)
y = self.norm[0](x)
x2 = self.fp8_linear.apply(
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)
layer1 = TestFP8Layer(self.w[0], self.wscale[0], input_scale=self.scale[0])
x2 = self.fp8_linear.apply_weights(layer1, y)
# make sure resid is used for replacement to work
y2, resid = self.norm[1](x2, resid)
x3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
layer2 = TestFP8Layer(self.w[1], self.wscale[1], input_scale=self.scale[1])
x3 = self.fp8_linear.apply_weights(layer2, y2)
y3, resid = self.norm[2](x3, resid) # use resid here
x4 = self.fp8_linear.apply(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
layer3 = TestFP8Layer(self.w[2], self.wscale[2], input_scale=self.scale[2])
x4 = self.fp8_linear.apply_weights(
layer3,
y3,
)
y4, resid = self.norm[3](x4, resid) # use resid here

View File

@ -26,14 +26,19 @@ from vllm.distributed.parallel_state import (
initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
GroupShape,
)
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables
from ..utils import has_module_attribute, multi_gpu_test
from ..utils import TestFP8Layer, has_module_attribute, multi_gpu_test
from .backend import TestBackend
@ -81,43 +86,49 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
self.eps = eps
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.w = [
self.input_scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.weight = [
torch.rand(hidden_size, hidden_size)
.to(dtype=current_platform.fp8_dtype())
.t()
for _ in range(3)
]
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
self.fp8_linear = init_fp8_linear_kernel(
act_q_static=True,
act_q_group_shape=GroupShape.PER_TENSOR,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
z2 = self.fp8_linear.apply(
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
layer1 = TestFP8Layer(
self.weight[0], self.weight_scale[0], input_scale=self.input_scale[0]
)
z2 = self.fp8_linear.apply_weights(layer1, y)
x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)
z3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
layer2 = TestFP8Layer(
self.weight[1], self.weight_scale[1], input_scale=self.input_scale[1]
)
z3 = self.fp8_linear.apply(layer2, y2)
x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here
z4 = self.fp8_linear.apply(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
layer3 = TestFP8Layer(
self.weight[2], self.weight_scale[2], input_scale=self.input_scale[2]
)
z4 = self.fp8_linear.apply(layer3, y3)
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4

View File

@ -28,16 +28,23 @@ from vllm.config import (
set_current_vllm_config,
)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
kNvfp4Quant,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
from vllm.v1.kv_cache_interface import AttentionSpec
from ..utils import TestFP8Layer
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
@ -170,11 +177,18 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.quant_key.scale.static,
act_quant_group_shape=self.quant_key.scale.group_shape,
)
if self.quant_key.scale.group_shape.is_per_tensor():
weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR
else:
weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL
self.fp8_linear = init_fp8_linear_kernel(
act_q_static=self.quant_key.scale.static,
act_q_group_shape=self.quant_key.scale.group_shape,
weight_quant_strategy=weight_quant_strategy,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
hidden_size = self.num_qo_heads * self.head_size
self.w = kwargs.get(
"w",
@ -190,12 +204,8 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Forward pass that creates the pattern to be fused."""
attn_output = self.attn(q, k, v)
return self.fp8_linear.apply(
input=attn_output,
weight=self.w["weight"],
weight_scale=self.w["wscale"],
input_scale=self.w["scale"],
)
layer = TestFP8Layer(self.w["weight"], self.w["wscale"], self.w["scale"])
return self.fp8_linear.apply_weights(layer, attn_output)
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):

View File

@ -27,11 +27,17 @@ from vllm.distributed.parallel_state import (
initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables
from ..utils import multi_gpu_test
from ..utils import TestFP8Layer, multi_gpu_test
from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype()
@ -107,8 +113,13 @@ class TestQuantModel(torch.nn.Module):
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
self.fp8_linear = init_fp8_linear_kernel(
act_q_static=True,
act_q_group_shape=GroupShape.PER_TENSOR,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
self.scale = torch.rand(1, dtype=torch.float32)
# Create a weight that is compatible with torch._scaled_mm,
# which expects a column-major layout.
@ -138,14 +149,9 @@ class TestQuantModel(torch.nn.Module):
# layer normalization
norm_output, residual_output = self.norm(all_reduce, residual)
# scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply(
norm_output,
self.w,
self.wscale,
input_scale=self.scale.to(norm_output.device),
)
layer = TestFP8Layer(None, None, self.scale.to(norm_output.device))
fp8_linear_result = self.fp8_linear.apply(layer, norm_output)
return fp8_linear_result, residual_output

View File

@ -24,13 +24,18 @@ from vllm.config import (
set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8StaticTensorSym,
kNvfp4Quant,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
maybe_create_device_identity,
)
from vllm.platforms import current_platform
@ -50,22 +55,26 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.wscale = torch.rand(1, dtype=torch.float32)
self.scale = torch.rand(1, dtype=torch.float32)
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
self.weight_scale = torch.rand(1, dtype=torch.float32)
self.input_scale = torch.rand(1, dtype=torch.float32)
self.input_scale_ub = None
self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
self.fp8_linear = init_fp8_linear_kernel(
act_q_static=True,
act_q_group_shape=GroupShape.PER_TENSOR,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
def forward(self, x):
y = self.silu_and_mul(x)
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
x2 = self.fp8_linear.apply_weights(self, y)
return x2
def ops_in_model_before(self):

View File

@ -1411,3 +1411,14 @@ def flat_product(*iterables: Iterable[Any]):
for element in itertools.product(*iterables):
normalized = (e if isinstance(e, tuple) else (e,) for e in element)
yield tuple(itertools.chain(*normalized))
class TestFP8Layer(torch.nn.Module):
    """Minimal stand-in layer for ScaledMMLinearKernels.

    Exposes exactly the attributes a scaled-mm kernel's ``apply_weights``
    reads from its layer argument: ``weight``, ``weight_scale``,
    ``input_scale`` and ``input_scale_ub``.
    """

    def __init__(self, weight, weight_scale, input_scale):
        super().__init__()
        # Store tensors under the attribute names the kernels look up.
        self.weight = weight
        self.input_scale = input_scale
        self.weight_scale = weight_scale
        # Tests here never bound the activation scale.
        self.input_scale_ub = None

View File

@ -2,10 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from typing import Any, Literal, Self
from typing import Any, Literal
from pydantic import model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self
from vllm.config.utils import config

View File

@ -33,7 +33,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz,
)
@ -97,9 +96,6 @@ class FBGEMMFp8Config(QuantizationConfig):
class FBGEMMFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: FBGEMMFp8Config):
self.quant_config = quant_config
self.fp8_linear = Fp8LinearOp(
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN
)
self.out_dtype = torch.get_default_dtype()
self.fp8_linear_kernel = init_fp8_linear_kernel(

View File

@ -41,7 +41,7 @@ class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
weight_quant_strategy: ScaledMMLinearQuantStrategy
activation_group_shape: GroupShape
out_dtype: torch.dtype
out_dtype: torch.dtype | None
_FP8ParamsT = tuple[

View File

@ -64,7 +64,7 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
],
}
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel, covariant=True)
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig)

View File

@ -7,11 +7,9 @@ import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.platforms import current_platform
FP8ScaledMMCallBack = Callable[..., torch.Tensor]
def apply_weights_fp8(
scaled_mm_func: FP8ScaledMMCallBack,
scaled_mm_func: Callable[..., torch.Tensor],
quant_fp8_func: QuantFP8,
w: torch.Tensor,
x: torch.Tensor,