diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index a10645227383e..ef8ad92a923e6 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -25,8 +25,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.platforms import current_platform -from ..utils import TestFP8Layer +from ..utils import TestFP8Layer from .backend import TestBackend TEST_FP8 = current_platform.supports_fp8() @@ -43,8 +43,14 @@ class TestSiluMul(torch.nn.Module): self.input_scale = torch.rand(1, dtype=torch.float32) if TEST_FP8: self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() - self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, self.weight, - self.weight_scale, self.input_scale) + self.fp8_linear = TestFP8Layer( + self.quant_key, + self.quant_key, + self.weight, + self.weight_scale, + self.input_scale, + ) + def forward(self, x): y = self.silu_and_mul(x) if TEST_FP8: @@ -87,8 +93,13 @@ class TestFusedAddRMSNorm(torch.nn.Module): ) self.weight_scale = torch.rand(1, dtype=torch.float32) self.input_scale = torch.rand(1, dtype=torch.float32) - self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, - self.weight, self.weight_scale, self.input_scale) + self.fp8_linear = TestFP8Layer( + self.quant_key, + self.quant_key, + self.weight, + self.weight_scale, + self.input_scale, + ) def forward(self, hidden_states, residual): # Reshape input diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index e627c67288cfa..a8ac8eb576da2 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -18,7 +18,6 @@ from vllm.config import ( VllmConfig, ) from vllm.model_executor.layers.layernorm import RMSNorm - from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, @@ -74,12 +73,27 @@ class TestModel(torch.nn.Module): ] with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear_1 = TestFP8Layer(self.activation_quant_key, self.weight_quant_key, - self.w[0], self.wscale[0], self.scale[0]) - self.fp8_linear_2 = TestFP8Layer(self.activation_quant_key, self.weight_quant_key, - self.w[1], self.wscale[1], self.scale[1]) - self.fp8_linear_3 = TestFP8Layer(self.activation_quant_key, self.weight_quant_key, - self.w[2], self.wscale[2], self.scale[2]) + self.fp8_linear_1 = TestFP8Layer( + self.activation_quant_key, + self.weight_quant_key, + self.w[0], + self.wscale[0], + self.scale[0], + ) + self.fp8_linear_2 = TestFP8Layer( + self.activation_quant_key, + self.weight_quant_key, + self.w[1], + self.wscale[1], + self.scale[1], + ) + self.fp8_linear_3 = TestFP8Layer( + self.activation_quant_key, + self.weight_quant_key, + self.w[2], + self.wscale[2], + self.scale[2], + ) self.enable_rms_norm_custom_op = self.norm[0].enabled() self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled() diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 161d703b79f18..bda2620d3e2fe 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -26,7 +26,6 @@ from vllm.distributed.parallel_state import ( initialize_model_parallel, ) from vllm.model_executor.layers.layernorm import RMSNorm - from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, ) @@ -91,14 +90,29 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): for _ in range(3) ] - self.fp8_linear_1 = TestFP8Layer(self.quant_key,self.quant_key, - self.weight[0],self.wscale[0], input_scale=self.input_scale[0]) - - self.fp8_linear_2 = TestFP8Layer(self.quant_key,self.quant_key, - self.weight[1],self.wscale[1], input_scale=self.input_scale[1]) - - self.fp8_linear_3 = TestFP8Layer(self.quant_key, self.quant_key, - self.weight[2], self.wscale[2],input_scale=self.input_scale[2]) + self.fp8_linear_1 = TestFP8Layer( + self.quant_key, + self.quant_key, + self.weight[0], + self.wscale[0], + input_scale=self.input_scale[0], + ) + + self.fp8_linear_2 = TestFP8Layer( + self.quant_key, + self.quant_key, + self.weight[1], + self.wscale[1], + input_scale=self.input_scale[1], + ) + + self.fp8_linear_3 = TestFP8Layer( + self.quant_key, + self.quant_key, + self.weight[2], + self.wscale[2], + input_scale=self.input_scale[2], + ) def forward(self, hidden_states): # avoid having graph input be an arg to a pattern directly @@ -106,7 +120,6 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): x = resid = tensor_model_parallel_all_reduce(z) y = self.norm[0](x) - z2 = self.fp8_linear_1(y) x2 = tensor_model_parallel_all_reduce(z2) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 1762af27d190e..60e01a0b0b639 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -28,7 +28,6 @@ from vllm.config import ( set_current_vllm_config, ) from vllm.forward_context import get_forward_context, set_forward_context - from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, @@ -172,7 +171,6 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - hidden_size = self.num_qo_heads * self.head_size self.w = kwargs.get( "w", @@ -184,8 +182,13 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): "scale": torch.tensor([1.0], dtype=torch.float32, device=self.device), }, ) - self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, self.w["weight"], - self.w["wscale"], self.w["scale"]) + self.fp8_linear = TestFP8Layer( + self.quant_key, + self.quant_key, + self.w["weight"], + self.w["wscale"], + self.w["scale"], + ) def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): """Forward pass that creates the pattern to be fused.""" diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index 0e422f4ee1321..fc4d38c8f8374 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -27,7 +27,6 @@ from vllm.distributed.parallel_state import ( initialize_model_parallel, ) from vllm.model_executor.layers.layernorm import RMSNorm - from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, ) @@ -117,10 +116,10 @@ class TestQuantModel(torch.nn.Module): # which expects a column-major layout. self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() self.wscale = torch.rand(1, dtype=torch.float32) - self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, - self.w, self.wscale, self.scale) + self.fp8_linear = TestFP8Layer( + self.quant_key, self.quant_key, self.w, self.wscale, self.scale + ) - def forward(self, hidden_states, residual): """ Forward pass implementing the operations in the FX graph diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 6e6f54a7fbb23..56b36856f7f29 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -24,7 +24,6 @@ from vllm.config import ( set_current_vllm_config, ) from vllm.model_executor.layers.activation import SiluAndMul - from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, kNvfp4Quant, @@ -56,10 +55,14 @@ class TestSiluMulFp8QuantModel(torch.nn.Module): self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, - self.weight, self.weight_scale, self.input_scale) + self.fp8_linear = TestFP8Layer( + self.quant_key, + self.quant_key, + self.weight, + self.weight_scale, + self.input_scale, + ) - self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled() diff --git a/tests/utils.py b/tests/utils.py index 5c2e10f473182..ba28886e60795 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -42,6 +42,10 @@ from vllm.distributed import ( ) from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.cli.serve import ServeSubcommand +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.model_executor.model_loader import get_model_loader from vllm.platforms import current_platform from vllm.transformers_utils.tokenizer import get_tokenizer @@ -49,8 +53,6 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.mem_constants import GB_bytes from vllm.utils.network_utils import get_open_port from vllm.utils.torch_utils import cuda_device_count_stateless -from vllm.model_executor.layers.quantization.kernels.scaled_mm import init_fp8_linear_kernel -from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey if current_platform.is_rocm(): from amdsmi import ( @@ -1429,32 +1431,36 @@ class TestFP8Layer(torch.nn.Module): weight (torch.Tensor): Weight tensor for linear transformation. weight_scale (torch.Tensor): Per-tensor or per-group scale for weights. input_scale (torch.Tensor): Scale tensor for input quantization. - out_dtype (torch.dtype, optional): Output tensor data type. Defaults to torch.get_default_dtype(). + out_dtype (torch.dtype, optional): Output tensor data type. + Defaults to torch.get_default_dtype(). """ - def __init__(self, - activation_quant_key: QuantKey, - weight_quant_key: QuantKey, - weight:torch.Tensor, - weight_scale:torch.Tensor, - input_scale:torch.Tensor, - out_dtype: torch.dtype = torch.get_default_dtype() - ): + + def __init__( + self, + activation_quant_key: QuantKey, + weight_quant_key: QuantKey, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: torch.Tensor, + out_dtype: torch.dtype | None = None, + ): super().__init__() self.weight_scale = weight_scale self.weight = weight self.input_scale = input_scale self.input_scale_ub = None - + out_dtype = torch.get_default_dtype() if out_dtype is None else out_dtype self.kernel = init_fp8_linear_kernel( activation_quant_key=activation_quant_key, weight_quant_key=weight_quant_key, out_dtype=out_dtype, module_name=self.__class__.__name__, ) - + def is_quant_fp8_enabled(self) -> bool: return self.kernel.quant_fp8.enabled() - def forward(self, y: torch.Tensor, bias: torch.Tensor | None=None) -> torch.Tensor: + def forward( + self, y: torch.Tensor, bias: torch.Tensor | None = None + ) -> torch.Tensor: return self.kernel.apply_weights(self, y, bias) - diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 0d14c13180ab1..a1c60fadce6d6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -30,6 +30,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, kFp8DynamicTokenSym, kFp8StaticTensorSym, + kFp8StaticTokenSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_block_fp8_supported, @@ -51,9 +52,13 @@ strategy_to_parameter_type = { STATIC_QUANT = True DYNAMIC_QUANT = False -quant_keys = { - STATIC_QUANT: (kFp8StaticTensorSym, kFp8StaticTensorSym), - DYNAMIC_QUANT: (kFp8DynamicTokenSym, kFp8StaticTensorSym), +activation_quant_key_mapping = { + STATIC_QUANT: kFp8StaticTensorSym, + DYNAMIC_QUANT: kFp8DynamicTokenSym, +} +weight_quant_key_mapping = { + QuantizationStrategy.CHANNEL: kFp8StaticTokenSym, + QuantizationStrategy.TENSOR: kFp8StaticTensorSym, } logger = init_logger(__name__) @@ -78,7 +83,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) else: - activation_quant_key, weight_quant_key = quant_keys[is_static_input_scheme] + activation_quant_key = activation_quant_key_mapping[is_static_input_scheme] + weight_quant_key = weight_quant_key_mapping[self.strategy] self.fp8_linear = init_fp8_linear_kernel( activation_quant_key=activation_quant_key, weight_quant_key=weight_quant_key, @@ -143,7 +149,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader) layer.register_parameter("input_scale", input_scale) - layer.register_parameter("input_scale_ub", None) + layer.input_scale_ub = None def process_weights_after_loading(self, layer) -> None: if self.strategy == QuantizationStrategy.TENSOR: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index c04bcef7bb0b5..6b613b21066a6 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -451,7 +451,7 @@ class Fp8LinearMethod(LinearMethodBase): weight_loader=weight_loader, ) layer.register_parameter("weight", weight) - layer.register_parameter("input_scale_ub", None) + layer.input_scale_ub = None # If checkpoint is serialized fp8, load them. # Otherwise, wait until process_weights_after_loading. diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 67d0772895786..4a3f74f591269 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -17,9 +17,8 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( CutlassFP8ScaledMMLinearKernel, CutlassScaledMMLinearKernel, ) - from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import ( - FlashInferScaledMMLinearKernel + FlashInferScaledMMLinearKernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import ( ChannelWiseTorchScaledMMLinearKernel, @@ -64,7 +63,7 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = PerTensorTorchScaledMMLinearKernel, RowWiseTorchScaledMMLinearKernel, ChannelWiseTorchScaledMMLinearKernel, - ], + ], PlatformEnum.ROCM: [ ROCmScaledMMLinearKernel, PerTensorTorchScaledMMLinearKernel, @@ -164,7 +163,7 @@ def init_fp8_linear_kernel( logger.info_once( "Selected %s for %s", - kernel_type.__class__.__name__, + kernel_type.__name__, module_name, scope="global", ) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 343539c10fa87..819348c5b938e 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -174,7 +174,7 @@ class QuarkW8A8Fp8(QuarkScheme): input_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("input_scale", input_scale) - layer.register_parameter("input_scale_ub", None) + layer.input_scale_ub = None self.fp8_linear = init_fp8_linear_kernel( activation_quant_key=self.activation_quant_key,