diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 11ae96e930da7..4d979f075d782 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -20,8 +20,13 @@ from vllm.config import ( ) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearQuantStrategy, +) from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape -from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.platforms import current_platform @@ -35,21 +40,23 @@ class TestSiluMul(torch.nn.Module): def __init__(self, hidden_size: int = 128): super().__init__() self.silu_and_mul = SiluAndMul() - self.wscale = torch.rand(1, dtype=torch.float32) - self.scale = torch.rand(1, dtype=torch.float32) - + self.weight_scale = torch.rand(1, dtype=torch.float32) + self.input_scale = torch.rand(1, dtype=torch.float32) + self.input_scale_ub = None if TEST_FP8: - self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() - self.fp8_linear = Fp8LinearOp( - act_quant_static=True, - act_quant_group_shape=GroupShape.PER_TENSOR, + self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() + self.fp8_linear = init_fp8_linear_kernel( + act_q_static=True, + act_q_group_shape=GroupShape.PER_TENSOR, + weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, ) def forward(self, x): y = self.silu_and_mul(x) if TEST_FP8: - x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale) - return x2 + return 
self.fp8_linear.apply_weights(self, y) else: return y @@ -81,11 +88,19 @@ class TestFusedAddRMSNorm(torch.nn.Module): torch.nn.init.normal_(self.gate_proj, std=0.02) if TEST_FP8: - self.fp8_linear = Fp8LinearOp(act_quant_static=True) - - self.scale = torch.rand(1, dtype=torch.float32) - self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() - self.wscale = torch.rand(1, dtype=torch.float32) + self.fp8_linear = init_fp8_linear_kernel( + act_q_static=True, + act_q_group_shape=GroupShape.PER_TENSOR, + weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, + ) + self.weight = ( + torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() + ) + self.weight_scale = torch.rand(1, dtype=torch.float32) + self.input_scale = torch.rand(1, dtype=torch.float32) + self.input_scale_ub = None def forward(self, hidden_states, residual): # Reshape input @@ -99,13 +114,9 @@ class TestFusedAddRMSNorm(torch.nn.Module): norm_output, residual_output = self.norm(mm, residual) if TEST_FP8: + self.input_scale = self.input_scale.to(norm_output.device) # scaled_mm with static input quantization - fp8_linear_result = self.fp8_linear.apply( - norm_output, - self.w, - self.wscale, - input_scale=self.scale.to(norm_output.device), - ) + fp8_linear_result = self.fp8_linear.apply_weights(self, norm_output) return fp8_linear_result, residual_output diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 286f2276367a0..ed925a4d55cca 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -18,19 +18,24 @@ from vllm.config import ( VllmConfig, ) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearQuantStrategy, +) from 
vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, ScaleDesc, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity, ) from vllm.platforms import current_platform -from ..utils import override_cutlass_fp8_supported +from ..utils import TestFP8Layer, override_cutlass_fp8_supported from .backend import TestBackend FP8_DTYPE = current_platform.fp8_dtype() @@ -54,6 +59,8 @@ class TestModel(torch.nn.Module): self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN + weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR + quant_scale = ScaleDesc(torch.float32, static, group_shape) self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) if static: @@ -66,9 +73,12 @@ class TestModel(torch.nn.Module): ] with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear = Fp8LinearOp( - act_quant_static=static, - act_quant_group_shape=group_shape, + self.fp8_linear = init_fp8_linear_kernel( + act_q_static=static, + act_q_group_shape=group_shape, + weight_quant_strategy=weight_quant_strategy, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, ) self.enable_rms_norm_custom_op = self.norm[0].enabled() @@ -79,20 +89,20 @@ class TestModel(torch.nn.Module): x = resid = torch.relu(x) y = self.norm[0](x) - x2 = self.fp8_linear.apply( - y, self.w[0], self.wscale[0], input_scale=self.scale[0] - ) + layer1 = TestFP8Layer(self.w[0], self.wscale[0], input_scale=self.scale[0]) + x2 = self.fp8_linear.apply_weights(layer1, y) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - x3 = self.fp8_linear.apply( - y2, self.w[1], self.wscale[1], input_scale=self.scale[1] - ) + layer2 = TestFP8Layer(self.w[1], self.wscale[1], 
input_scale=self.scale[1]) + x3 = self.fp8_linear.apply_weights(layer2, y2) y3, resid = self.norm[2](x3, resid) # use resid here - x4 = self.fp8_linear.apply( - y3, self.w[2], self.wscale[2], input_scale=self.scale[2] + layer3 = TestFP8Layer(self.w[2], self.wscale[2], input_scale=self.scale[2]) + x4 = self.fp8_linear.apply_weights( + layer3, + y3, ) y4, resid = self.norm[3](x4, resid) # use resid here diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 6d0a0ed7d89d2..2dc6f8d2f925d 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -26,14 +26,19 @@ from vllm.distributed.parallel_state import ( initialize_model_parallel, ) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearQuantStrategy, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, GroupShape, ) from vllm.platforms import current_platform from vllm.utils.system_utils import update_environment_variables -from ..utils import has_module_attribute, multi_gpu_test +from ..utils import TestFP8Layer, has_module_attribute, multi_gpu_test from .backend import TestBackend @@ -81,43 +86,49 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): self.eps = eps self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] - self.w = [ + self.input_scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + self.weight = [ torch.rand(hidden_size, hidden_size) .to(dtype=current_platform.fp8_dtype()) .t() for _ in range(3) ] - self.fp8_linear = Fp8LinearOp( - act_quant_static=True, - act_quant_group_shape=GroupShape.PER_TENSOR, + self.fp8_linear = init_fp8_linear_kernel( + 
act_q_static=True, + act_q_group_shape=GroupShape.PER_TENSOR, + weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, ) - self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] - def forward(self, hidden_states): # avoid having graph input be an arg to a pattern directly z = torch.relu(hidden_states) x = resid = tensor_model_parallel_all_reduce(z) y = self.norm[0](x) - z2 = self.fp8_linear.apply( - y, self.w[0], self.wscale[0], input_scale=self.scale[0] + layer1 = TestFP8Layer( + self.weight[0], self.wscale[0], input_scale=self.input_scale[0] ) + z2 = self.fp8_linear.apply_weights(layer1, y) x2 = tensor_model_parallel_all_reduce(z2) y2, resid = self.norm[1](x2, resid) - z3 = self.fp8_linear.apply( - y2, self.w[1], self.wscale[1], input_scale=self.scale[1] + layer2 = TestFP8Layer( + self.weight[1], self.wscale[1], input_scale=self.input_scale[1] ) + z3 = self.fp8_linear.apply_weights(layer2, y2) x3 = tensor_model_parallel_all_reduce(z3) y3, resid = self.norm[2](x3, resid) # use resid here - z4 = self.fp8_linear.apply( - y3, self.w[2], self.wscale[2], input_scale=self.scale[2] + layer3 = TestFP8Layer( + self.weight[2], self.wscale[2], input_scale=self.input_scale[2] ) + z4 = self.fp8_linear.apply_weights(layer3, y3) + + x4 = tensor_model_parallel_all_reduce(z4) y4, resid = self.norm[3](x4, resid) # use resid here return y4 diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index fecb1e2e918fe..a6ebf46d98ddb 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -28,16 +28,23 @@ from vllm.config import ( set_current_vllm_config, ) from vllm.forward_context import get_forward_context, set_forward_context +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
+ ScaledMMLinearQuantStrategy, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, kNvfp4Quant, ) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer from vllm.v1.kv_cache_interface import AttentionSpec +from ..utils import TestFP8Layer + FP8_DTYPE = current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 @@ -170,11 +177,18 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.quant_key.scale.static, - act_quant_group_shape=self.quant_key.scale.group_shape, - ) + if self.quant_key.scale.group_shape.is_per_tensor(): + weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR + else: + weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL + self.fp8_linear = init_fp8_linear_kernel( + act_q_static=self.quant_key.scale.static, + act_q_group_shape=self.quant_key.scale.group_shape, + weight_quant_strategy=weight_quant_strategy, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, + ) hidden_size = self.num_qo_heads * self.head_size self.w = kwargs.get( "w", @@ -190,12 +204,8 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): """Forward pass that creates the pattern to be fused.""" attn_output = self.attn(q, k, v) - return self.fp8_linear.apply( - input=attn_output, - weight=self.w["weight"], - weight_scale=self.w["wscale"], - input_scale=self.w["scale"], - ) + layer = TestFP8Layer(self.w["weight"], self.w["wscale"], self.w["scale"]) + return self.fp8_linear.apply_weights(layer, attn_output) class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): diff --git a/tests/compile/test_sequence_parallelism.py 
b/tests/compile/test_sequence_parallelism.py index e909cf7393ad3..007339cd86f7b 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -27,11 +27,17 @@ from vllm.distributed.parallel_state import ( initialize_model_parallel, ) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearQuantStrategy, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform from vllm.utils.system_utils import update_environment_variables -from ..utils import multi_gpu_test +from ..utils import TestFP8Layer, multi_gpu_test from .backend import TestBackend FP8_DTYPE = current_platform.fp8_dtype() @@ -107,8 +113,13 @@ class TestQuantModel(torch.nn.Module): # Initialize weights torch.nn.init.normal_(self.gate_proj, std=0.02) - self.fp8_linear = Fp8LinearOp(act_quant_static=True) - + self.fp8_linear = init_fp8_linear_kernel( + act_q_static=True, + act_q_group_shape=GroupShape.PER_TENSOR, + weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, + ) self.scale = torch.rand(1, dtype=torch.float32) # Create a weight that is compatible with torch._scaled_mm, # which expects a column-major layout. 
@@ -138,14 +149,9 @@ class TestQuantModel(torch.nn.Module): # layer normalization norm_output, residual_output = self.norm(all_reduce, residual) - # scaled_mm with static input quantization - fp8_linear_result = self.fp8_linear.apply( - norm_output, - self.w, - self.wscale, - input_scale=self.scale.to(norm_output.device), - ) + layer = TestFP8Layer(self.w, self.wscale, self.scale.to(norm_output.device)) + fp8_linear_result = self.fp8_linear.apply_weights(layer, norm_output) return fp8_linear_result, residual_output diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 0ddb82b7c3fc2..2ce52b97f13e3 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -24,13 +24,18 @@ from vllm.config import ( set_current_vllm_config, ) from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearQuantStrategy, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, kFp8StaticTensorSym, kNvfp4Quant, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, maybe_create_device_identity, ) from vllm.platforms import current_platform @@ -50,22 +55,26 @@ class TestSiluMulFp8QuantModel(torch.nn.Module): def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs): super().__init__() self.silu_and_mul = SiluAndMul() - self.wscale = torch.rand(1, dtype=torch.float32) - self.scale = torch.rand(1, dtype=torch.float32) - - self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() + self.weight_scale = torch.rand(1, dtype=torch.float32) + self.input_scale = torch.rand(1, dtype=torch.float32) + self.input_scale_ub = None + self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() with 
override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear = Fp8LinearOp( - act_quant_static=True, - act_quant_group_shape=GroupShape.PER_TENSOR, + self.fp8_linear = init_fp8_linear_kernel( + act_q_static=True, + act_q_group_shape=GroupShape.PER_TENSOR, + weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, ) + self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() def forward(self, x): y = self.silu_and_mul(x) - x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale) + x2 = self.fp8_linear.apply_weights(self, y) return x2 def ops_in_model_before(self): diff --git a/tests/utils.py b/tests/utils.py index af4ce6ebaeda2..bb3bbc750350a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1411,3 +1411,14 @@ def flat_product(*iterables: Iterable[Any]): for element in itertools.product(*iterables): normalized = (e if isinstance(e, tuple) else (e,) for e in element) yield tuple(itertools.chain(*normalized)) + + +class TestFP8Layer(torch.nn.Module): + """Helper class for ScaledMMLinearKernels.""" + + def __init__(self, weight, weight_scale, input_scale): + super().__init__() + self.weight_scale = weight_scale + self.weight = weight + self.input_scale = input_scale + self.input_scale_ub = None diff --git a/vllm/config/structured_outputs.py b/vllm/config/structured_outputs.py index 85b6e42264a42..eb1cc7220b8fe 100644 --- a/vllm/config/structured_outputs.py +++ b/vllm/config/structured_outputs.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib -from typing import Any, Literal, Self +from typing import Any, Literal from pydantic import model_validator from pydantic.dataclasses import dataclass +from typing_extensions import Self from vllm.config.utils import config diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py 
b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index a7b8e6ddda719..5fa419ebaa91a 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -33,7 +33,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, ) @@ -97,9 +96,6 @@ class FBGEMMFp8Config(QuantizationConfig): class FBGEMMFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: FBGEMMFp8Config): self.quant_config = quant_config - self.fp8_linear = Fp8LinearOp( - act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN - ) self.out_dtype = torch.get_default_dtype() self.fp8_linear_kernel = init_fp8_linear_kernel( diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index 329078f0a489c..f3bff8cae0ef7 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -41,7 +41,7 @@ class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): weight_quant_strategy: ScaledMMLinearQuantStrategy activation_group_shape: GroupShape - out_dtype: torch.dtype + out_dtype: torch.dtype | None _FP8ParamsT = tuple[ diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 08c1ced5f08d1..901f0649a6d48 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -64,7 +64,7 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, 
list[type[FP8ScaledMMLinearKernel]]] = ], } -_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel, covariant=True) +_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel) _KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py index 8323690817d62..62bbacbc782cd 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py @@ -7,11 +7,9 @@ import torch from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.platforms import current_platform -FP8ScaledMMCallBack = Callable[..., torch.Tensor] - def apply_weights_fp8( - scaled_mm_func: FP8ScaledMMCallBack, + scaled_mm_func: Callable[..., torch.Tensor], quant_fp8_func: QuantFP8, w: torch.Tensor, x: torch.Tensor,