diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 6d27d10f687ab..aa4d2c8cf4537 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -18,18 +18,34 @@ from vllm.config import ( VllmConfig, ) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( + CutlassFP8ScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import ( + FlashInferScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import ( + ChannelWiseTorchScaledMMLinearKernel, + PerTensorTorchScaledMMLinearKernel, + RowWiseTorchScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import ( + ROCmScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + FP8ScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, ScaleDesc, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - cutlass_fp8_supported, maybe_create_device_identity, ) from vllm.platforms import current_platform -from ..utils import TestFP8Layer, override_cutlass_fp8_supported +from ..utils import TestFP8Layer from .backend import TestBackend FP8_DTYPE = current_platform.fp8_dtype() @@ -44,14 +60,12 @@ class TestModel(torch.nn.Module): hidden_size: int, eps: float, static: bool, - cuda_force_torch: bool, + force_kernel: FP8ScaledMMLinearKernel, *args, **kwargs, ): super().__init__(*args, **kwargs) - self.cuda_force_torch = cuda_force_torch self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] - self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN act_quant_scale = ScaleDesc(torch.float32, static, group_shape) @@ -67,22 +81,30 @@ class TestModel(torch.nn.Module): self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] else: self.scale = [None for _ in range(3)] + + if group_shape == GroupShape.PER_TOKEN: + self.wscale = [ + torch.rand((hidden_size, 1), dtype=torch.float32) for _ in range(3) + ] + else: + self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + self.w = [ torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() for _ in range(3) ] - with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear_layers = [ - TestFP8Layer( - self.activation_quant_key, - self.weight_quant_key, - self.w[i], - self.wscale[i], - input_scale=self.scale[i], - ) - for i in range(3) - ] + self.fp8_linear_layers = [ + TestFP8Layer( + self.activation_quant_key, + self.weight_quant_key, + self.w[i], + self.wscale[i], + input_scale=self.scale[i], + force_kernel=force_kernel, + ) + for i in range(3) + ] self.enable_rms_norm_custom_op = self.norm[0].enabled() self.enable_quant_fp8_custom_op = self.fp8_linear_layers[ @@ -128,6 +150,21 @@ class TestModel(torch.nn.Module): ) +ROCM_FP8_KERNELS = [ + ROCmScaledMMLinearKernel, + PerTensorTorchScaledMMLinearKernel, + RowWiseTorchScaledMMLinearKernel, + ChannelWiseTorchScaledMMLinearKernel, +] + +CUDA_FP8_KERNELS = [ + FlashInferScaledMMLinearKernel, + CutlassFP8ScaledMMLinearKernel, + PerTensorTorchScaledMMLinearKernel, + ChannelWiseTorchScaledMMLinearKernel, +] + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("num_tokens", [257]) @@ -135,10 +172,8 @@ class TestModel(torch.nn.Module): @pytest.mark.parametrize("static", [True, False]) @pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) @pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False]) -# cuda_force_torch used to test torch code path on platforms that -# cutlass_fp8_supported() == True. @pytest.mark.parametrize( - "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True] + "force_kernel", CUDA_FP8_KERNELS if current_platform.is_cuda() else ROCM_FP8_KERNELS ) @pytest.mark.skipif( not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm" @@ -151,7 +186,7 @@ def test_fusion_rmsnorm_quant( static, enable_rms_norm_custom_op, enable_quant_fp8_custom_op, - cuda_force_torch, + force_kernel, ): torch.set_default_device("cuda") torch.set_default_dtype(dtype) @@ -179,8 +214,12 @@ def test_fusion_rmsnorm_quant( backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) backend2 = TestBackend(noop_pass, cleanup_pass) - model = TestModel(hidden_size, eps, static, cuda_force_torch) + model = TestModel(hidden_size, eps, static, force_kernel) + # skip the test if we cannot force the kernel + selected_kernels = [layer.kernel for layer in model.fp8_linear_layers] + if not any(isinstance(kernel, force_kernel) for kernel in selected_kernels): + pytest.skip(f"{force_kernel.__name__} couldn't be forced") # First dimension dynamic x = torch.rand(num_tokens, hidden_size) torch._dynamo.mark_dynamic(x, 0) diff --git a/tests/utils.py b/tests/utils.py index ba28886e60795..8fac003b8e9bf 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -45,6 +45,9 @@ from vllm.entrypoints.cli.serve import ServeSubcommand from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + FP8ScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.model_executor.model_loader import get_model_loader from vllm.platforms import current_platform @@ -1443,6 +1446,7 @@ class TestFP8Layer(torch.nn.Module): weight_scale: torch.Tensor, input_scale: torch.Tensor, out_dtype: torch.dtype | None = None, + force_kernel: FP8ScaledMMLinearKernel | None = None, ): super().__init__() self.weight_scale = weight_scale @@ -1454,7 +1458,7 @@ class TestFP8Layer(torch.nn.Module): activation_quant_key=activation_quant_key, weight_quant_key=weight_quant_key, out_dtype=out_dtype, - module_name=self.__class__.__name__, + force_kernel=force_kernel, ) def is_quant_fp8_enabled(self) -> bool: diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 4a3f74f591269..b033cc7905e4e 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -61,7 +61,6 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = FlashInferScaledMMLinearKernel, CutlassFP8ScaledMMLinearKernel, PerTensorTorchScaledMMLinearKernel, - RowWiseTorchScaledMMLinearKernel, ChannelWiseTorchScaledMMLinearKernel, ], PlatformEnum.ROCM: [ @@ -76,10 +75,38 @@ _KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel) _KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig) +def can_implement_scaled_mm_linear_kernel( + kernel: type[_KernelT], config: _KernelConfigT, compute_capability: int | None +) -> tuple[bool, str]: + if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","): + return False, f" {kernel.__name__} disabled by environment variable" + + # If the current platform uses compute_capability, + # make sure the kernel supports the compute cability. + if compute_capability is not None: + kernel_min_capability = kernel.get_min_capability() + if ( + kernel_min_capability is not None + and kernel_min_capability > compute_capability + ): + return ( + False, + f"{kernel.__name__} requires capability " + f"{kernel_min_capability}, current compute capability " + f"is {compute_capability}", + ) + can_implement, failure_reason = kernel.can_implement(config) + if not can_implement: + return (False, f" {kernel.__name__} cannot implement due to: {failure_reason}") + + return True, "" + + def choose_scaled_mm_linear_kernel( config: _KernelConfigT, possible_kernels: dict[PlatformEnum, list[type[_KernelT]]], compute_capability: int | None = None, + force_kernel: type[_KernelT] | None = None, ) -> type[_KernelT]: """ Choose a _KernelT that can implement the given config for the @@ -94,6 +121,9 @@ def choose_scaled_mm_linear_kernel( compute_capability (Optional[int], optional): The compute capability of the target device, if None uses `current_platform` to get the compute capability. Defaults to None. + force_kernel (Optional[type[_KernelT]]): An Optional forced kernel to override + the possible_kernels if it can be implemented. If None, it will only try the + possible kernels. Raises: ValueError: If no kernel can implement the given config. @@ -107,40 +137,32 @@ def choose_scaled_mm_linear_kernel( if _cc is not None: compute_capability = _cc[0] * 10 + _cc[1] - failure_reasons = [] + failure_reason_list = [] + + if force_kernel is not None: + can_implement, failure_reason = can_implement_scaled_mm_linear_kernel( + force_kernel, config, compute_capability + ) + if can_implement: + return force_kernel + + logger.info_once( + "Tried to force %s, but the kernel couldn't be implemented", + force_kernel.__name__, + scope="global", + ) + for kernel in possible_kernels[current_platform._enum]: - if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","): - failure_reasons.append( - f" {kernel.__name__} disabled by environment variable" - ) - continue - - # If the current platform uses compute_capability, - # make sure the kernel supports the compute cability. - if compute_capability is not None: - kernel_min_capability = kernel.get_min_capability() - if ( - kernel_min_capability is not None - and kernel_min_capability > compute_capability - ): - failure_reasons.append( - f"{kernel.__name__} requires capability " - f"{kernel_min_capability}, current compute capability " - f"is {compute_capability}" - ) - continue - - can_implement, failure_reason = kernel.can_implement(config) + can_implement, failure_reason = can_implement_scaled_mm_linear_kernel( + kernel, config, compute_capability + ) if can_implement: return kernel - else: - failure_reasons.append( - f" {kernel.__name__} cannot implement due to: {failure_reason}" - ) + failure_reason_list.append(failure_reason) raise ValueError( "Failed to find a kernel that can implement the " - "ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reasons) + "ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reason_list) ) @@ -148,7 +170,8 @@ def init_fp8_linear_kernel( activation_quant_key: QuantKey, weight_quant_key: QuantKey, out_dtype: torch.dtype, - module_name: str, + force_kernel: type[FP8ScaledMMLinearKernel] | None = None, + module_name: str | None = None, ) -> FP8ScaledMMLinearKernel: scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( weight_quant_key=weight_quant_key, @@ -157,16 +180,16 @@ def init_fp8_linear_kernel( ) kernel_type = choose_scaled_mm_linear_kernel( - scaled_mm_linear_kernel_config, - _POSSIBLE_FP8_KERNELS, + scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, force_kernel=force_kernel ) - logger.info_once( - "Selected %s for %s", - kernel_type.__name__, - module_name, - scope="global", - ) + if module_name: + logger.info_once( + "Selected %s for %s", + kernel_type.__name__, + module_name, + scope="global", + ) return kernel_type( scaled_mm_linear_kernel_config,