From 93fb7071f5ecbdde0c8c03a68f3b0fc692d5e8f0 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 4 Nov 2025 13:10:29 +0000 Subject: [PATCH] reduce test boilerplate Signed-off-by: vllmellm --- tests/compile/test_functionalization.py | 27 ++++---------- tests/compile/test_fusion.py | 30 ++++++---------- tests/compile/test_fusion_all_reduce.py | 36 ++++++++----------- tests/compile/test_fusion_attn.py | 15 +++----- tests/compile/test_sequence_parallelism.py | 16 +++------ tests/compile/test_silu_mul_quant_fusion.py | 21 +++++------ tests/utils.py | 40 +++++++++++++++++++-- 7 files changed, 87 insertions(+), 98 deletions(-) diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index a40f8beccdc20..a10645227383e 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -20,14 +20,12 @@ from vllm.config import ( ) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - init_fp8_linear_kernel, -) from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, ) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.platforms import current_platform +from ..utils import TestFP8Layer from .backend import TestBackend @@ -43,20 +41,14 @@ class TestSiluMul(torch.nn.Module): self.silu_and_mul = SiluAndMul() self.weight_scale = torch.rand(1, dtype=torch.float32) self.input_scale = torch.rand(1, dtype=torch.float32) - self.input_scale_ub = None if TEST_FP8: self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() - self.fp8_linear = init_fp8_linear_kernel( - activation_quant_key=self.quant_key, - weight_quant_key=self.quant_key, - out_dtype=torch.get_default_dtype(), - module_name=self.__class__.__name__, - ) - + self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, self.weight, + 
self.weight_scale, self.input_scale) def forward(self, x): y = self.silu_and_mul(x) if TEST_FP8: - return self.fp8_linear.apply_weights(self, y) + return self.fp8_linear(y) else: return y @@ -90,18 +82,13 @@ class TestFusedAddRMSNorm(torch.nn.Module): torch.nn.init.normal_(self.gate_proj, std=0.02) if TEST_FP8: - self.fp8_linear = init_fp8_linear_kernel( - activation_quant_key=self.quant_key, - weight_quant_key=self.quant_key, - out_dtype=torch.get_default_dtype(), - module_name=self.__class__.__name__, - ) self.weight = ( torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() ) self.weight_scale = torch.rand(1, dtype=torch.float32) self.input_scale = torch.rand(1, dtype=torch.float32) - self.input_scale_ub = None + self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, + self.weight, self.weight_scale, self.input_scale) def forward(self, hidden_states, residual): # Reshape input @@ -117,7 +104,7 @@ class TestFusedAddRMSNorm(torch.nn.Module): if TEST_FP8: self.input_scale = self.input_scale.to(norm_output.device) # scaled_mm with static input quantization - fp8_linear_result = self.fp8_linear.apply_weights(self, norm_output) + fp8_linear_result = self.fp8_linear(norm_output) return fp8_linear_result, residual_output diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 6270344c2eb35..e627c67288cfa 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -18,9 +18,7 @@ from vllm.config import ( VllmConfig, ) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - init_fp8_linear_kernel, -) + from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, @@ -76,36 +74,30 @@ class TestModel(torch.nn.Module): ] with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear = init_fp8_linear_kernel( - activation_quant_key=self.activation_quant_key, - 
weight_quant_key=self.weight_quant_key, - out_dtype=torch.get_default_dtype(), - module_name=self.__class__.__name__, - ) + self.fp8_linear_1 = TestFP8Layer(self.activation_quant_key, self.weight_quant_key, + self.w[0], self.wscale[0], self.scale[0]) + self.fp8_linear_2 = TestFP8Layer(self.activation_quant_key, self.weight_quant_key, + self.w[1], self.wscale[1], self.scale[1]) + self.fp8_linear_3 = TestFP8Layer(self.activation_quant_key, self.weight_quant_key, + self.w[2], self.wscale[2], self.scale[2]) self.enable_rms_norm_custom_op = self.norm[0].enabled() - self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() + self.enable_quant_fp8_custom_op = self.fp8_linear_1.is_quant_fp8_enabled() def forward(self, x): # avoid having graph input be an arg to a pattern directly x = resid = torch.relu(x) y = self.norm[0](x) - layer1 = TestFP8Layer(self.w[0], self.wscale[0], input_scale=self.scale[0]) - x2 = self.fp8_linear.apply_weights(layer1, y) + x2 = self.fp8_linear_1(y) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - layer2 = TestFP8Layer(self.w[1], self.wscale[1], input_scale=self.scale[1]) - x3 = self.fp8_linear.apply_weights(layer2, y2) + x3 = self.fp8_linear_2(y2) y3, resid = self.norm[2](x3, resid) # use resid here - layer3 = TestFP8Layer(self.w[2], self.wscale[2], input_scale=self.scale[2]) - x4 = self.fp8_linear.apply_weights( - layer3, - y3, - ) + x4 = self.fp8_linear_3(y3) y4, resid = self.norm[3](x4, resid) # use resid here return y4 diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 5e2c46f8ea919..161d703b79f18 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -26,9 +26,7 @@ from vllm.distributed.parallel_state import ( initialize_model_parallel, ) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - init_fp8_linear_kernel, -) + 
from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, ) @@ -93,12 +91,14 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): for _ in range(3) ] - self.fp8_linear = init_fp8_linear_kernel( - activation_quant_key=self.quant_key, - weight_quant_key=self.quant_key, - out_dtype=torch.get_default_dtype(), - module_name=self.__class__.__name__, - ) + self.fp8_linear_1 = TestFP8Layer(self.quant_key,self.quant_key, + self.weight[0],self.weight_scale[0], input_scale=self.input_scale[0]) + + self.fp8_linear_2 = TestFP8Layer(self.quant_key,self.quant_key, + self.weight[1],self.weight_scale[1], input_scale=self.input_scale[1]) + + self.fp8_linear_3 = TestFP8Layer(self.quant_key, self.quant_key, + self.weight[2], self.weight_scale[2],input_scale=self.input_scale[2]) def forward(self, hidden_states): # avoid having graph input be an arg to a pattern directly @@ -106,26 +106,18 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): x = resid = tensor_model_parallel_all_reduce(z) y = self.norm[0](x) - layer1 = TestFP8Layer( - self.weight[0], self.weight_scale[0], input_scale=self.input_scale[0] - ) - z2 = self.fp8_linear.apply_weights(layer1, y) + + z2 = self.fp8_linear_1(y) x2 = tensor_model_parallel_all_reduce(z2) y2, resid = self.norm[1](x2, resid) - layer2 = TestFP8Layer( - self.weight[1], self.weight_scale[1], input_scale=self.input_scale[1] - ) - z3 = self.fp8_linear.apply(layer2, y2) + z3 = self.fp8_linear_2(y2) x3 = tensor_model_parallel_all_reduce(z3) y3, resid = self.norm[2](x3, resid) # use resid here - layer3 = TestFP8Layer( - self.weight[2], self.weight_scale[2], input_scale=self.input_scale[2] - ) - z4 = self.fp8_linear.apply(layer3, y3) + z4 = self.fp8_linear_3(y3) x4 = tensor_model_parallel_all_reduce(z4) y4, resid = self.norm[3](x4, resid) # use resid here @@ -138,7 +130,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): return [ torch.ops.vllm.all_reduce.default, 
torch.ops._C.static_scaled_fp8_quant.default - if self.fp8_linear.quant_fp8.enabled() + if self.fp8_linear_1.is_quant_fp8_enabled() else torch.ops.aten.reciprocal.default, ] diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 9068e304f551d..1762af27d190e 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -28,9 +28,7 @@ from vllm.config import ( set_current_vllm_config, ) from vllm.forward_context import get_forward_context, set_forward_context -from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - init_fp8_linear_kernel, -) + from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, @@ -174,12 +172,6 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.fp8_linear = init_fp8_linear_kernel( - activation_quant_key=self.quant_key, - weight_quant_key=self.quant_key, - out_dtype=torch.get_default_dtype(), - module_name=self.__class__.__name__, - ) hidden_size = self.num_qo_heads * self.head_size self.w = kwargs.get( @@ -192,12 +184,13 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): "scale": torch.tensor([1.0], dtype=torch.float32, device=self.device), }, ) + self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, self.w["weight"], + self.w["wscale"], self.w["scale"]) def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): """Forward pass that creates the pattern to be fused.""" attn_output = self.attn(q, k, v) - layer = TestFP8Layer(self.w["weight"], self.w["wscale"], self.w["scale"]) - return self.fp8_linear.apply_weights(layer, attn_output) + return self.fp8_linear(attn_output) class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index f579815338a98..0e422f4ee1321 
100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -27,9 +27,7 @@ from vllm.distributed.parallel_state import ( initialize_model_parallel, ) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - init_fp8_linear_kernel, -) + from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, ) @@ -114,18 +112,15 @@ class TestQuantModel(torch.nn.Module): # Initialize weights torch.nn.init.normal_(self.gate_proj, std=0.02) - self.fp8_linear = init_fp8_linear_kernel( - activation_quant_key=self.quant_key, - weight_quant_key=self.quant_key, - out_dtype=torch.get_default_dtype(), - module_name=self.__class__.__name__, - ) self.scale = torch.rand(1, dtype=torch.float32) # Create a weight that is compatible with torch._scaled_mm, # which expects a column-major layout. self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() self.wscale = torch.rand(1, dtype=torch.float32) + self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, + self.w, self.wscale, self.scale) + def forward(self, hidden_states, residual): """ Forward pass implementing the operations in the FX graph @@ -150,8 +145,7 @@ class TestQuantModel(torch.nn.Module): # layer normalization norm_output, residual_output = self.norm(all_reduce, residual) # scaled_mm with static input quantization - layer = TestFP8Layer(None, None, self.scale.to(norm_output.device)) - fp8_linear_result = self.fp8_linear.apply(layer, norm_output) + fp8_linear_result = self.fp8_linear(norm_output) return fp8_linear_result, residual_output diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 20e7c2955d01b..6e6f54a7fbb23 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -24,9 +24,7 @@ from vllm.config import ( set_current_vllm_config, ) from 
vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - init_fp8_linear_kernel, -) + from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, kNvfp4Quant, @@ -36,7 +34,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ) from vllm.platforms import current_platform -from ..utils import override_cutlass_fp8_supported +from ..utils import TestFP8Layer, override_cutlass_fp8_supported from .backend import TestBackend FP8_DTYPE = current_platform.fp8_dtype() @@ -55,22 +53,19 @@ class TestSiluMulFp8QuantModel(torch.nn.Module): self.silu_and_mul = SiluAndMul() self.weight_scale = torch.rand(1, dtype=torch.float32) self.input_scale = torch.rand(1, dtype=torch.float32) - self.input_scale_ub = None self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear = init_fp8_linear_kernel( - activation_quant_key=self.quant_key, - weight_quant_key=self.quant_key, - out_dtype=torch.get_default_dtype(), - module_name=self.__class__.__name__, - ) + self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, + self.weight, self.weight_scale, self.input_scale) + + self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() - self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() + self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled() def forward(self, x): y = self.silu_and_mul(x) - x2 = self.fp8_linear.apply_weights(self, y) + x2 = self.fp8_linear(y) return x2 def ops_in_model_before(self): diff --git a/tests/utils.py b/tests/utils.py index bb3bbc750350a..5c2e10f473182 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -49,6 +49,8 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.mem_constants import GB_bytes from vllm.utils.network_utils import get_open_port from vllm.utils.torch_utils import 
cuda_device_count_stateless +from vllm.model_executor.layers.quantization.kernels.scaled_mm import init_fp8_linear_kernel +from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey if current_platform.is_rocm(): from amdsmi import ( @@ -1414,11 +1416,45 @@ def flat_product(*iterables: Iterable[Any]): class TestFP8Layer(torch.nn.Module): - """Helper class for ScaledMMLinearKernels.""" + """ + Test helper class for evaluating FP8 linear operations with quantization. - def __init__(self, weight, weight_scale, input_scale): + It supports configurable activation and weight quantization parameters, + and provides a forward method that applies the FP8 linear transformation + with optional bias. + + Args: + activation_quant_key (QuantKey): Key for activation quantization configuration. + weight_quant_key (QuantKey): Key for weight quantization configuration. + weight (torch.Tensor): Weight tensor for linear transformation. + weight_scale (torch.Tensor): Per-tensor or per-group scale for weights. + input_scale (torch.Tensor): Scale tensor for input quantization. + out_dtype (torch.dtype, optional): Output tensor data type. Defaults to None, resolved to torch.get_default_dtype() when the layer is constructed. + """ + def __init__(self, + activation_quant_key: QuantKey, + weight_quant_key: QuantKey, + weight:torch.Tensor, + weight_scale:torch.Tensor, + input_scale:torch.Tensor, + out_dtype: torch.dtype | None = None + ): super().__init__() self.weight_scale = weight_scale self.weight = weight self.input_scale = input_scale self.input_scale_ub = None + + self.kernel = init_fp8_linear_kernel( + activation_quant_key=activation_quant_key, + weight_quant_key=weight_quant_key, + out_dtype=torch.get_default_dtype() if out_dtype is None else out_dtype, + module_name=self.__class__.__name__, + ) + + def is_quant_fp8_enabled(self) -> bool: + return self.kernel.quant_fp8.enabled() + + def forward(self, y: torch.Tensor, bias: torch.Tensor | None=None) -> torch.Tensor: + return self.kernel.apply_weights(self, y, bias) +