diff --git a/tests/compile/distributed/test_fusion_all_reduce.py b/tests/compile/distributed/test_fusion_all_reduce.py index fc8d1f98ebf87..92abb90ef5dfc 100644 --- a/tests/compile/distributed/test_fusion_all_reduce.py +++ b/tests/compile/distributed/test_fusion_all_reduce.py @@ -26,14 +26,13 @@ from vllm.distributed.parallel_state import ( initialize_model_parallel, ) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, - GroupShape, +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, ) from vllm.platforms import current_platform from vllm.utils.system_utils import update_environment_variables -from ...utils import has_module_attribute, multi_gpu_test +from ...utils import TestFP8Layer, has_module_attribute, multi_gpu_test from ..backend import TestBackend @@ -75,25 +74,32 @@ class TestAllReduceRMSNormModel(torch.nn.Module): class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): + quant_key = kFp8StaticTensorSym + def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] - self.w = [ + self.input_scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + self.weight = [ torch.rand(hidden_size, hidden_size) .to(dtype=current_platform.fp8_dtype()) .t() for _ in range(3) ] - self.fp8_linear = Fp8LinearOp( - act_quant_static=True, - act_quant_group_shape=GroupShape.PER_TENSOR, - ) - - self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + self.fp8_linear_layers = [ + TestFP8Layer( + self.quant_key, + self.quant_key, + self.weight[i], + self.wscale[i], + input_scale=self.input_scale[i], + ) + for i in range(3) + ] def forward(self, hidden_states): # avoid having graph input be an arg to a pattern directly @@ -101,23 
+107,18 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): x = resid = tensor_model_parallel_all_reduce(z) y = self.norm[0](x) - z2 = self.fp8_linear.apply( - y, self.w[0], self.wscale[0], input_scale=self.scale[0] - ) + z2 = self.fp8_linear_layers[0](y) x2 = tensor_model_parallel_all_reduce(z2) y2, resid = self.norm[1](x2, resid) - z3 = self.fp8_linear.apply( - y2, self.w[1], self.wscale[1], input_scale=self.scale[1] - ) + z3 = self.fp8_linear_layers[1](y2) x3 = tensor_model_parallel_all_reduce(z3) y3, resid = self.norm[2](x3, resid) # use resid here - z4 = self.fp8_linear.apply( - y3, self.w[2], self.wscale[2], input_scale=self.scale[2] - ) + z4 = self.fp8_linear_layers[2](y3) + x4 = tensor_model_parallel_all_reduce(z4) y4, resid = self.norm[3](x4, resid) # use resid here return y4 @@ -129,7 +130,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): return [ torch.ops.vllm.all_reduce.default, torch.ops._C.static_scaled_fp8_quant.default - if self.fp8_linear.quant_fp8.enabled() + if self.fp8_linear_layers[0].is_quant_fp8_enabled() else torch.ops.aten.reciprocal.default, ] diff --git a/tests/compile/distributed/test_sequence_parallelism.py b/tests/compile/distributed/test_sequence_parallelism.py index d9fdc3acc3d6f..13329f3cda306 100644 --- a/tests/compile/distributed/test_sequence_parallelism.py +++ b/tests/compile/distributed/test_sequence_parallelism.py @@ -27,12 +27,13 @@ from vllm.distributed.parallel_state import ( initialize_model_parallel, ) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape -from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, +) from vllm.platforms import current_platform from vllm.utils.system_utils import update_environment_variables -from ...utils import multi_gpu_test +from ...utils import 
TestFP8Layer, multi_gpu_test from ..backend import TestBackend FP8_DTYPE = current_platform.fp8_dtype() @@ -93,6 +94,8 @@ class TestAllReduceRMSNormModel(torch.nn.Module): class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): + quant_key = kFp8StaticTensorSym + def __init__(self, hidden_size=16, eps=1e-6): super().__init__() self.vllm_config = get_current_vllm_config() @@ -106,37 +109,32 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): .t() for _ in range(3) ] - - self.fp8_linear = Fp8LinearOp( - act_quant_static=True, - act_quant_group_shape=GroupShape.PER_TENSOR, - ) - self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + self.fp8_linear_layers = [ + TestFP8Layer( + self.quant_key, self.quant_key, self.w[i], self.wscale[i], self.scale[i] + ) + for i in range(3) + ] + def forward(self, hidden_states): # avoid having graph input be an arg to a pattern directly z = torch.relu(hidden_states) x = resid = tensor_model_parallel_all_reduce(z) y = self.norm[0](x) - z2 = self.fp8_linear.apply( - y, self.w[0], self.wscale[0], input_scale=self.scale[0] - ) + z2 = self.fp8_linear_layers[0](y) x2 = tensor_model_parallel_all_reduce(z2) y2, resid = self.norm[1](x2, resid) - z3 = self.fp8_linear.apply( - y2, self.w[1], self.wscale[1], input_scale=self.scale[1] - ) + z3 = self.fp8_linear_layers[1](y2) x3 = tensor_model_parallel_all_reduce(z3) y3, resid = self.norm[2](x3, resid) # use resid here - z4 = self.fp8_linear.apply( - y3, self.w[2], self.wscale[2], input_scale=self.scale[2] - ) + z4 = self.fp8_linear_layers[2](y3) x4 = tensor_model_parallel_all_reduce(z4) y4, resid = self.norm[3](x4, resid) # use resid here return y4 @@ -159,7 +157,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): return [ torch.ops._C.fused_add_rms_norm.default, ] - elif self.fp8_linear.quant_fp8.enabled(): + elif any(layer.is_quant_fp8_enabled() for layer in self.fp8_linear_layers): return [ torch.ops._C.static_scaled_fp8_quant.default, ] 
diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index ad5ead36e2310..bfd59bac54ca4 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -20,11 +20,13 @@ from vllm.config import ( ) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape -from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, +) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.platforms import current_platform +from ..utils import TestFP8Layer from .backend import TestBackend TEST_FP8 = current_platform.supports_fp8() @@ -32,24 +34,27 @@ FP8_DTYPE = current_platform.fp8_dtype() class TestSiluMul(torch.nn.Module): + quant_key = kFp8StaticTensorSym + def __init__(self, hidden_size: int = 128): super().__init__() self.silu_and_mul = SiluAndMul() - self.wscale = torch.rand(1, dtype=torch.float32) - self.scale = torch.rand(1, dtype=torch.float32) - + self.weight_scale = torch.rand(1, dtype=torch.float32) + self.input_scale = torch.rand(1, dtype=torch.float32) if TEST_FP8: - self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() - self.fp8_linear = Fp8LinearOp( - act_quant_static=True, - act_quant_group_shape=GroupShape.PER_TENSOR, + self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() + self.fp8_linear = TestFP8Layer( + self.quant_key, + self.quant_key, + self.weight, + self.weight_scale, + self.input_scale, ) def forward(self, x): y = self.silu_and_mul(x) if TEST_FP8: - x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale) - return x2 + return self.fp8_linear(y) else: return y @@ -67,6 +72,8 @@ class TestSiluMul(torch.nn.Module): class 
TestFusedAddRMSNorm(torch.nn.Module): + quant_key = kFp8StaticTensorSym + def __init__(self, hidden_size=16, intermediate_size=32): super().__init__() self.hidden_size = hidden_size @@ -81,11 +88,18 @@ class TestFusedAddRMSNorm(torch.nn.Module): torch.nn.init.normal_(self.gate_proj, std=0.02) if TEST_FP8: - self.fp8_linear = Fp8LinearOp(act_quant_static=True) - - self.scale = torch.rand(1, dtype=torch.float32) - self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() - self.wscale = torch.rand(1, dtype=torch.float32) + self.weight = ( + torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() + ) + self.weight_scale = torch.rand(1, dtype=torch.float32) + self.input_scale = torch.rand(1, dtype=torch.float32) + self.fp8_linear = TestFP8Layer( + self.quant_key, + self.quant_key, + self.weight, + self.weight_scale, + self.input_scale, + ) def forward(self, hidden_states, residual): # Reshape input @@ -99,13 +113,9 @@ class TestFusedAddRMSNorm(torch.nn.Module): norm_output, residual_output = self.norm(mm, residual) if TEST_FP8: + self.input_scale = self.input_scale.to(norm_output.device) # scaled_mm with static input quantization - fp8_linear_result = self.fp8_linear.apply( - norm_output, - self.w, - self.wscale, - input_scale=self.scale.to(norm_output.device), - ) + fp8_linear_result = self.fp8_linear(norm_output) return fp8_linear_result, residual_output diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 7755e9f9b7380..ac92fb13fdc09 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -5,6 +5,7 @@ import pytest import torch +import vllm.config import vllm.plugins from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass @@ -20,8 +21,22 @@ from vllm.config import ( VllmConfig, ) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.utils.fp8_utils 
import ( - W8A8BlockFp8LinearOp, +from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( + CutlassFP8ScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import ( + FlashInferScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import ( + ChannelWiseTorchScaledMMLinearKernel, + PerTensorTorchScaledMMLinearKernel, + RowWiseTorchScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import ( + ROCmScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + FP8ScaledMMLinearKernel, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, @@ -29,15 +44,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ScaleDesc, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, cutlass_block_fp8_supported, - cutlass_fp8_supported, - maybe_create_device_identity, ) from vllm.platforms import current_platform -from vllm.utils.deep_gemm import is_deep_gemm_supported +from vllm.utils.deep_gemm import ( + is_deep_gemm_supported, +) -from ..utils import override_cutlass_fp8_supported +from ..utils import TestBlockFP8Layer, TestFP8Layer from .backend import TestBackend FP8_DTYPE = current_platform.fp8_dtype() @@ -45,157 +59,260 @@ FP8_DTYPE = current_platform.fp8_dtype() RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default +# Kernel and group_shape combinations: (kernel, group_shape) +# CUDA kernels +CUDA_KERNEL_GROUPSHAPE_COMBINATIONS = [ + # FlashInferScaledMMLinearKernel supports both per-tensor and per-token + (FlashInferScaledMMLinearKernel, GroupShape.PER_TOKEN), + (FlashInferScaledMMLinearKernel, GroupShape.PER_TENSOR), + # CutlassFP8ScaledMMLinearKernel supports both per-tensor and per-token + (CutlassFP8ScaledMMLinearKernel, 
GroupShape.PER_TOKEN), + (CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR), + # PerTensorTorchScaledMMLinearKernel only supports per-tensor + (PerTensorTorchScaledMMLinearKernel, GroupShape.PER_TENSOR), + # ChannelWiseTorchScaledMMLinearKernel only supports per-token + (ChannelWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN), + # Blockwise group shapes (no kernel abstraction) + (None, GroupShape(1, 128)), + (None, GroupShape(1, 64)), +] + +# ROCm kernels +ROCM_KERNEL_GROUPSHAPE_COMBINATIONS = [ + # ROCmScaledMMLinearKernel supports both per-tensor and per-token + (ROCmScaledMMLinearKernel, GroupShape.PER_TOKEN), + (ROCmScaledMMLinearKernel, GroupShape.PER_TENSOR), + # RowWiseTorchScaledMMLinearKernel only supports per-token + (RowWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN), + # ChannelWiseTorchScaledMMLinearKernel only supports per-token + (ChannelWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN), + # Blockwise group shapes (no kernel abstraction) + (None, GroupShape(1, 128)), + (None, GroupShape(1, 64)), +] + +KERNEL_GROUPSHAPE_COMBINATIONS = ( + CUDA_KERNEL_GROUPSHAPE_COMBINATIONS + if current_platform.is_cuda() + else ROCM_KERNEL_GROUPSHAPE_COMBINATIONS +) + +# For Aiter tests we toggle use_aiter_quant_op +AITER_KERNEL_GROUPSHAPE_COMBINATIONS = [ + # Per-token with ROCmScaledMMLinearKernel + (ROCmScaledMMLinearKernel, GroupShape.PER_TOKEN, True), + (ROCmScaledMMLinearKernel, GroupShape.PER_TOKEN, False), + # Per-token with RowWiseTorchScaledMMLinearKernel + (RowWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN, True), + (RowWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN, False), + # Per-token with ChannelWiseTorchScaledMMLinearKernel + (ChannelWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN, True), + (ChannelWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN, False), + # Blockwise (no kernel abstraction) + (None, GroupShape(1, 128), True), +] + class TestModel(torch.nn.Module): def __init__( self, hidden_size: int, eps: float, + 
force_kernel: FP8ScaledMMLinearKernel | None, group_shape: GroupShape, - use_aiter: bool = False, - cuda_force_torch: bool = False, - use_aiter_quant_op: bool = True, + use_aiter_fusion: bool = False, + use_aiter_quant: bool = False, *args, **kwargs, ): super().__init__(*args, **kwargs) - self.use_aiter = use_aiter - self.use_aiter_quant_op = use_aiter_quant_op - self.cuda_force_torch = cuda_force_torch + self.fp8_linear_layers: list[torch.nn.Module] self.group_shape = group_shape - self.enable_quant_fp8_custom_op = None # Will be set later if applicable - + self.use_aiter_quant_op = use_aiter_quant + self.use_aiter_fusion = use_aiter_fusion self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] + self.enable_rms_norm_custom_op = self.norm[0].enabled() - # Setup quantization scale descriptor - static = group_shape == GroupShape.PER_TENSOR and not use_aiter - quant_scale = ScaleDesc(torch.float32, static, group_shape) - self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) + # Determine if blockwise based on group_shape + is_blockwise = group_shape.is_per_group() - # Setup scales - if static: - self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + if is_blockwise: + self._init_blockwise( + hidden_size, group_shape, use_aiter_fusion, use_aiter_quant + ) else: - self.scale = [None for _ in range(3)] + self._init_nonblockwise( + hidden_size, group_shape, force_kernel, use_aiter_quant + ) - # Setup weights + def _init_nonblockwise( + self, + hidden_size: int, + group_shape: GroupShape, + force_kernel: FP8ScaledMMLinearKernel | None, + use_aiter_quant: bool, + ): + """Initialize non-blockwise (per-tensor/per-token) FP8 layers.""" + is_static = group_shape == GroupShape.PER_TENSOR + act_quant_scale_desc = ScaleDesc(torch.float32, is_static, group_shape) + w_quant_scale_desc = ScaleDesc(torch.float32, True, group_shape) + self.activation_quant_key = QuantKey( + dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True + ) + 
self.weight_quant_key = QuantKey( + dtype=FP8_DTYPE, scale=w_quant_scale_desc, symmetric=True + ) + + # Setup weight scales + wscale_shape = (1,) if group_shape.is_per_tensor() else (hidden_size, 1) + self.wscale = [torch.rand(wscale_shape, dtype=torch.float32) for _ in range(3)] + + self.act_scale = ( + [torch.rand(1, dtype=torch.float32) for _ in range(3)] + if is_static + else [None for _ in range(3)] + ) + + # Initialize weights (transposed for non-blockwise) + self.w = [ + torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() + for _ in range(3) + ] + + # Setup FP8 linear layers with kernel abstraction + self.fp8_linear_layers = [ + TestFP8Layer( + self.activation_quant_key, + self.weight_quant_key, + self.w[i], + self.wscale[i], + input_scale=self.act_scale[i], + force_kernel=force_kernel, + ) + for i in range(3) + ] + + # Enable aiter quantization if requested + for layer in self.fp8_linear_layers: + layer.kernel.quant_fp8.use_aiter = use_aiter_quant + + self.enable_quant_fp8_custom_op = self.fp8_linear_layers[ + 0 + ].is_quant_fp8_enabled() + + def _init_blockwise( + self, + hidden_size: int, + group_shape: GroupShape, + use_aiter_fusion: bool, + use_aiter_quant: bool, + ): + """Initialize blockwise FP8 layers.""" + act_quant_scale_desc = ScaleDesc(torch.float32, False, group_shape) + self.activation_quant_key = QuantKey( + dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True + ) + + # Setup weight scales (for blockwise quantization) + # Use aiter block size if aiter fusion is enabled + scale_size = ( + (hidden_size + 128 - 1) // 128 + if use_aiter_fusion + else hidden_size // group_shape[1] + ) + wscale_shape = (scale_size, scale_size) + self.wscale = [torch.rand(wscale_shape, dtype=torch.float32) for _ in range(3)] + + # Initialize weights (transposed if using aiter, otherwise not) self.w = [ torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3) ] - if not group_shape.is_per_group() or use_aiter: - self.w = 
[self.w[0].t() for _ in range(3)] + if use_aiter_fusion: + self.w = [w.t() for w in self.w] - # Setup weight scales - if group_shape.is_per_group(): - scale_size = ( - (hidden_size + 128 - 1) // 128 - if use_aiter - else hidden_size // group_shape[1] - ) - wscale_shape: tuple[int, ...] = (scale_size, scale_size) - else: - wscale_shape = (1,) - self.wscale = [torch.rand(wscale_shape, dtype=torch.float32) for _ in range(3)] - - # Setup FP8 linear operation - is_per_group = group_shape.is_per_group() - if is_per_group and use_aiter: - self.fp8_linear = W8A8BlockFp8LinearOp( - weight_group_shape=GroupShape(128, 128), - act_quant_group_shape=group_shape, - use_aiter_and_is_supported=use_aiter_quant_op, - ) - # AITER blockwise doesn't use enable_quant_fp8_custom_op - elif is_per_group: - self.fp8_linear = W8A8BlockFp8LinearOp( - weight_group_shape=GroupShape(group_shape[1], group_shape[1]), - act_quant_group_shape=group_shape, + self.fp8_linear_layers = [ + TestBlockFP8Layer( + group_shape=group_shape, + weight=self.w[i], + weight_scale=self.wscale[i], + input_scale=None, # Dynamic quantization for blockwise cutlass_block_fp8_supported=cutlass_block_fp8_supported(), - use_aiter_and_is_supported=False, + use_aiter_and_is_supported=use_aiter_quant, ) - self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled() - elif use_aiter: - self.fp8_linear = Fp8LinearOp( - act_quant_static=False, - act_quant_group_shape=group_shape, - ) - self.fp8_linear.quant_fp8.use_aiter = use_aiter_quant_op - self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() - else: - with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear = Fp8LinearOp( - act_quant_static=static, - act_quant_group_shape=group_shape, - ) - self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() + for i in range(3) + ] - self.enable_rms_norm_custom_op = self.norm[0].enabled() + self.enable_quant_fp8_custom_op = ( + False + if use_aiter_quant + else 
self.fp8_linear_layers[0].linear_op.input_quant_op.enabled() + ) def forward(self, x): # avoid having graph input be an arg to a pattern directly x = resid = torch.relu(x) y = self.norm[0](x) - x2 = self.fp8_linear.apply( - y, self.w[0], self.wscale[0], input_scale=self.scale[0] - ) + x2 = self.fp8_linear_layers[0](y) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - x3 = self.fp8_linear.apply( - y2, self.w[1], self.wscale[1], input_scale=self.scale[1] - ) + x3 = self.fp8_linear_layers[1](y2) y3, resid = self.norm[2](x3, resid) # use resid here - x4 = self.fp8_linear.apply( - y3, self.w[2], self.wscale[2], input_scale=self.scale[2] - ) + x4 = self.fp8_linear_layers[2](y3) y4, resid = self.norm[3](x4, resid) # use resid here return y4 def ops_in_model_before(self): - if ( - self.use_aiter - and self.group_shape.is_per_group() - and current_platform.is_fp8_fnuz() - ): - return [rocm_aiter_ops.get_group_quant_op()] - if self.use_aiter and self.group_shape.is_per_group(): - return [torch.ops.vllm.triton_per_token_group_quant_fp8.default] - if self.use_aiter and self.use_aiter_quant_op: - return [rocm_aiter_ops.get_per_token_quant_op()] - if self.use_aiter: - return [QUANT_OPS[self.quant_key]] - if self.enable_quant_fp8_custom_op: - return [QUANT_OPS[self.quant_key]] - return [torch.ops.aten.reciprocal] + if self.group_shape.is_per_group(): + # Blockwise path + if self.use_aiter_fusion and self.use_aiter_quant_op: + return [rocm_aiter_ops.get_group_quant_op()] + if self.use_aiter_fusion: + return [torch.ops.vllm.triton_per_token_group_quant_fp8.default] + else: + if self.use_aiter_quant_op: + return [rocm_aiter_ops.get_per_token_quant_op()] + + # Common path + return ( + [QUANT_OPS[self.activation_quant_key]] + if self.enable_quant_fp8_custom_op + else [torch.ops.aten.reciprocal] + ) def ops_in_model_after(self): - if self.use_aiter and self.group_shape.is_per_group(): - from vllm.compilation.rocm_aiter_fusion import ( - 
AiterFusedAddRMSFp8GroupQuantPattern, - AiterRMSFp8GroupQuantPattern, - ) + if self.use_aiter_fusion: + if self.group_shape.is_per_group(): + # Blockwise aiter fusion + from vllm.compilation.rocm_aiter_fusion import ( + AiterFusedAddRMSFp8GroupQuantPattern, + AiterRMSFp8GroupQuantPattern, + ) - return [ - AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP, - AiterRMSFp8GroupQuantPattern.FUSED_OP, - ] - if self.use_aiter: - from vllm.compilation.rocm_aiter_fusion import ( - AiterFusedAddRMSNormDynamicQuantPattern, - AiterRMSNormDynamicQuantPattern, - ) + return [ + AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP, + AiterRMSFp8GroupQuantPattern.FUSED_OP, + ] + else: + # Per-token aiter fusion + from vllm.compilation.rocm_aiter_fusion import ( + AiterFusedAddRMSNormDynamicQuantPattern, + AiterRMSNormDynamicQuantPattern, + ) - return [ - AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP, - AiterRMSNormDynamicQuantPattern.FUSED_OP, - ] + return [ + AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP, + AiterRMSNormDynamicQuantPattern.FUSED_OP, + ] + + # Regular fusion return [ - FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)], - FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)], + FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, True)], + FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, False)], ] def ops_in_model_before_partial(self): @@ -206,14 +323,6 @@ class TestModel(torch.nn.Module): ) -GROUP_SHAPES = [ - GroupShape.PER_TOKEN, - GroupShape.PER_TENSOR, - GroupShape(1, 128), - GroupShape(1, 64), -] - - def _run_fusion_test( model, fusion_pass, @@ -259,14 +368,9 @@ def _run_fusion_test( @pytest.mark.parametrize("hidden_size", [256]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) -@pytest.mark.parametrize("group_shape", GROUP_SHAPES) +@pytest.mark.parametrize("kernel_groupshape", KERNEL_GROUPSHAPE_COMBINATIONS) @pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) 
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False]) -# cuda_force_torch used to test torch code path on platforms that -# cutlass_fp8_supported() == True. -@pytest.mark.parametrize( - "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True] -) @pytest.mark.skipif( not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm" ) @@ -275,11 +379,12 @@ def test_fusion_rmsnorm_quant( hidden_size, num_tokens, eps, - group_shape, + kernel_groupshape, enable_rms_norm_custom_op, enable_quant_fp8_custom_op, - cuda_force_torch, ): + force_kernel, group_shape = kernel_groupshape + if not enable_quant_fp8_custom_op and group_shape.is_per_group(): pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization") @@ -310,15 +415,16 @@ def test_fusion_rmsnorm_quant( torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) - maybe_create_device_identity() fusion_pass = RMSNormQuantFusionPass(vllm_config) + model = TestModel( hidden_size=hidden_size, eps=eps, + force_kernel=force_kernel, group_shape=group_shape, - use_aiter=False, - cuda_force_torch=cuda_force_torch, + use_aiter_fusion=False, + use_aiter_quant=False, ) backend, _ = _run_fusion_test( @@ -339,19 +445,12 @@ def test_fusion_rmsnorm_quant( assert n_add_nodes(backend.graph_post_pass) == 2 -GROUP_SHAPE_QUANT_OPS_MATCHS = [ - (GroupShape.PER_TOKEN, True), - (GroupShape.PER_TOKEN, False), - (GroupShape(1, 128), True), -] - - @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [256]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize( - "group_shape, use_aiter_quant_op", GROUP_SHAPE_QUANT_OPS_MATCHS + "kernel_groupshape_quant", AITER_KERNEL_GROUPSHAPE_COMBINATIONS ) @pytest.mark.skipif( (not current_platform.is_rocm() or not IS_AITER_FOUND), @@ -362,10 +461,10 @@ def test_aiter_fusion_rmsnorm_quant( hidden_size: int, num_tokens: int, eps: 
float, - group_shape: GroupShape, - use_aiter_quant_op: bool, + kernel_groupshape_quant: tuple, monkeypatch: pytest.MonkeyPatch, ): + force_kernel, group_shape, use_aiter_quant_op = kernel_groupshape_quant vllm_config = VllmConfig( model_config=ModelConfig(dtype=dtype), compilation_config=CompilationConfig( @@ -379,20 +478,22 @@ def test_aiter_fusion_rmsnorm_quant( from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormFusionPass m.setenv("VLLM_ROCM_USE_AITER", "1") + rocm_aiter_ops.refresh_env_variables() torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) - maybe_create_device_identity() fusion_pass = RocmAiterRMSNormFusionPass(vllm_config) + model = TestModel( hidden_size=hidden_size, eps=eps, + force_kernel=force_kernel, group_shape=group_shape, - use_aiter=True, - use_aiter_quant_op=use_aiter_quant_op, + use_aiter_fusion=True, # Always use aiter fusion ops in aiter test + use_aiter_quant=use_aiter_quant_op, # Toggle aiter quantization ) _run_fusion_test( diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index db95dff5e0fc7..dbe35189ce1fc 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -34,11 +34,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, kNvfp4Quant, ) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer from vllm.v1.kv_cache_interface import AttentionSpec +from ..utils import TestFP8Layer + FP8_DTYPE = current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 @@ -171,11 +172,6 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.quant_key.scale.static, - act_quant_group_shape=self.quant_key.scale.group_shape, - ) - 
hidden_size = self.num_qo_heads * self.head_size self.w = kwargs.get( "w", @@ -187,16 +183,18 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): "scale": torch.tensor([1.0], dtype=torch.float32, device=self.device), }, ) + self.fp8_linear = TestFP8Layer( + self.quant_key, + self.quant_key, + self.w["weight"], + self.w["wscale"], + self.w["scale"], + ) def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): """Forward pass that creates the pattern to be fused.""" attn_output = self.attn(q, k, v) - return self.fp8_linear.apply( - input=attn_output, - weight=self.w["weight"], - weight_scale=self.w["wscale"], - input_scale=self.w["scale"], - ) + return self.fp8_linear(attn_output) class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index eb0dee8d4e399..d0ff6dd4bfc72 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -31,13 +31,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, kNvfp4Quant, ) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, - maybe_create_device_identity, -) from vllm.platforms import current_platform -from ..utils import override_cutlass_fp8_supported +from ..utils import TestFP8Layer, override_cutlass_fp8_supported from .backend import TestBackend FP8_DTYPE = current_platform.fp8_dtype() @@ -49,25 +45,30 @@ def is_nvfp4_supported(): class TestSiluMulFp8QuantModel(torch.nn.Module): + quant_key = kFp8StaticTensorSym + def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs): super().__init__() self.silu_and_mul = SiluAndMul() - self.wscale = torch.rand(1, dtype=torch.float32) - self.scale = torch.rand(1, dtype=torch.float32) - - self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() + self.weight_scale = torch.rand(1, 
dtype=torch.float32) + self.input_scale = torch.rand(1, dtype=torch.float32) + self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear = Fp8LinearOp( - act_quant_static=True, - act_quant_group_shape=GroupShape.PER_TENSOR, + self.fp8_linear = TestFP8Layer( + self.quant_key, + self.quant_key, + self.weight, + self.weight_scale, + self.input_scale, ) + self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() - self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() + self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled() def forward(self, x): y = self.silu_and_mul(x) - x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale) + x2 = self.fp8_linear(y) return x2 def ops_in_model_before(self): @@ -198,7 +199,6 @@ def test_fusion_silu_and_mul_quant( torch.set_default_device("cuda") torch.set_default_dtype(dtype) - maybe_create_device_identity() x = torch.rand(num_tokens, hidden_size * 2) diff --git a/tests/utils.py b/tests/utils.py index 1b338e93182a5..6f1d94e2a3dce 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -42,6 +42,17 @@ from vllm.distributed import ( ) from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.cli.serve import ServeSubcommand +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + FP8ScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + QuantKey, +) from vllm.model_executor.model_loader import get_model_loader from vllm.platforms import current_platform from vllm.tokenizers import get_tokenizer @@ -1311,3 +1322,96 @@ def flat_product(*iterables: Iterable[Any]): for element in 
itertools.product(*iterables): normalized = (e if isinstance(e, tuple) else (e,) for e in element) yield tuple(itertools.chain(*normalized)) + + +class TestFP8Layer(torch.nn.Module): + """ + Test helper class for evaluating FP8 linear operations with quantization. + + It supports configurable activation and weight quantization parameters, + and provides a forward method that applies the FP8 linear transformation + with optional bias. + + Args: + activation_quant_key (QuantKey): Key for activation quantization configuration. + weight_quant_key (QuantKey): Key for weight quantization configuration. + weight (torch.Tensor): Weight tensor for linear transformation. + weight_scale (torch.Tensor): Per-tensor or per-group scale for weights. + input_scale (torch.Tensor, optional): Scale tensor for input quantization. + Defaults to None. + out_dtype (torch.dtype, optional): Output tensor data type. + Defaults to torch.get_default_dtype(). + """ + + def __init__( + self, + activation_quant_key: QuantKey, + weight_quant_key: QuantKey, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, + force_kernel: FP8ScaledMMLinearKernel | None = None, + ): + super().__init__() + self.weight_scale = weight_scale + self.weight = weight + self.input_scale = input_scale + self.input_scale_ub = None + out_dtype = torch.get_default_dtype() if out_dtype is None else out_dtype + self.kernel = init_fp8_linear_kernel( + activation_quant_key=activation_quant_key, + weight_quant_key=weight_quant_key, + out_dtype=out_dtype, + force_kernel=force_kernel, + ) + + def is_quant_fp8_enabled(self) -> bool: + return self.kernel.quant_fp8.enabled() + + def forward( + self, y: torch.Tensor, bias: torch.Tensor | None = None + ) -> torch.Tensor: + return self.kernel.apply_weights(self, y, bias) + + +class TestBlockFP8Layer: + """ + Test wrapper for W8A8BlockFp8LinearOp to match TestFP8Layer interface. 
+ + This is a workaround until W8A8BlockFp8LinearOp implements + ScaledMMLinearKernel (i.e., a kernel abstraction for blockwise quantization). + """ + + def __init__( + self, + group_shape: GroupShape, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: torch.Tensor | None = None, + cutlass_block_fp8_supported: bool = False, + use_aiter_and_is_supported: bool = False, + ): + self.linear_op = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(group_shape[1], group_shape[1]), + act_quant_group_shape=group_shape, + cutlass_block_fp8_supported=cutlass_block_fp8_supported, + use_aiter_and_is_supported=use_aiter_and_is_supported, + ) + self.weight = weight + self.weight_scale = weight_scale + self.input_scale = input_scale + + def __call__( + self, y: torch.Tensor, bias: torch.Tensor | None = None + ) -> torch.Tensor: + return self.linear_op.apply( + input=y, + weight=self.weight, + weight_scale=self.weight_scale, + input_scale=self.input_scale, + bias=bias, + ) + + def is_quant_fp8_enabled(self) -> bool: + return self.linear_op.input_quant_op.enabled() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 758a54c10605a..205abe942bd01 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -8,9 +8,13 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrate from torch.nn import Parameter from vllm._aiter_ops import rocm_aiter_ops +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) from 
vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, create_fp8_input_scale, @@ -22,11 +26,14 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( process_fp8_weight_tensor_strategy, validate_fp8_block_shape, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, + kFp8StaticTokenSym, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, cutlass_block_fp8_supported, - maybe_create_device_identity, ) from vllm.model_executor.parameter import ( BlockQuantScaleParameter, @@ -42,6 +49,18 @@ strategy_to_parameter_type = { QuantizationStrategy.TENSOR: PerTensorScaleParameter, } +STATIC_QUANT = True +DYNAMIC_QUANT = False +activation_quant_key_mapping = { + STATIC_QUANT: kFp8StaticTensorSym, + DYNAMIC_QUANT: kFp8DynamicTokenSym, +} +weight_quant_key_mapping = { + QuantizationStrategy.CHANNEL: kFp8StaticTokenSym, + QuantizationStrategy.TENSOR: kFp8StaticTensorSym, +} +logger = init_logger(__name__) + class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool): @@ -49,22 +68,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): self.strategy = weight_quant.strategy self.out_dtype = torch.get_default_dtype() self.is_static_input_scheme = is_static_input_scheme - self.weight_block_size = self.weight_quant.block_structure - if self.weight_block_size is not None: - self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) - else: - self.act_q_group_shape = ( - GroupShape.PER_TENSOR - if is_static_input_scheme - else GroupShape.PER_TOKEN - ) - - self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() - self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled() if self.weight_block_size is not None: + 
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() + self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled() assert not self.is_static_input_scheme + self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( weight_group_shape=GroupShape(*self.weight_block_size), act_quant_group_shape=self.act_q_group_shape, @@ -72,9 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) else: - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.is_static_input_scheme, - act_quant_group_shape=self.act_q_group_shape, + activation_quant_key = activation_quant_key_mapping[is_static_input_scheme] + weight_quant_key = weight_quant_key_mapping[self.strategy] + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=activation_quant_key, + weight_quant_key=weight_quant_key, + out_dtype=self.out_dtype, + module_name=self.__class__.__name__, ) @classmethod @@ -93,8 +107,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): weight_loader: Callable, **kwargs, ): - maybe_create_device_identity() - output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes layer.weight_block_size = None @@ -134,6 +146,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader) layer.register_parameter("input_scale", input_scale) + layer.input_scale_ub = None + def process_weights_after_loading(self, layer) -> None: if self.strategy == QuantizationStrategy.TENSOR: weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy( @@ -190,11 +204,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): bias=bias, ) - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=layer.input_scale, - bias=bias, - ) + return
self.fp8_linear.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index 6fd0a6a1c822c..652feb1964575 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -11,8 +11,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - ScaledMMLinearLayerConfig, - choose_scaled_mm_linear_kernel, + init_int8_linear_kernel, ) from vllm.model_executor.parameter import ( BasevLLMParameter, @@ -25,8 +24,6 @@ logger = init_logger(__name__) class CompressedTensorsW8A8Int8(CompressedTensorsScheme): - _kernel_backends_being_used: set[str] = set() - def __init__( self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool ): @@ -50,18 +47,13 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ): layer.logical_widths = output_partition_sizes - scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( + self.kernel = init_int8_linear_kernel( is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL), is_static_input_scheme=self.is_static_input_scheme, input_symmetric=self.input_symmetric, + module_name=self.__class__.__name__, ) - kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config) - - if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__) - self._kernel_backends_being_used.add(kernel_type.__name__) - # WEIGHT weight = ModelWeightParameter( data=torch.empty( @@ -90,12 +82,12 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): 
layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE + input_zero_point = None + input_scale = None if self.is_static_input_scheme: input_scale = BasevLLMParameter( data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader ) - layer.register_parameter("input_scale", input_scale) - if not self.input_symmetric: # Note: compressed-tensors stores the zp using the same dtype # as the weights @@ -103,16 +95,11 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): input_zero_point = BasevLLMParameter( data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader ) - layer.register_parameter("input_zero_point", input_zero_point) - self.kernel = kernel_type( - c=scaled_mm_linear_kernel_config, - w_q_param_name="weight", - w_s_param_name="weight_scale", - i_s_param_name="input_scale", - i_zp_param_name="input_zero_point", - azp_adj_param_name="azp_adj", - ) + layer.register_parameter("input_zero_point", input_zero_point) + layer.register_parameter("input_scale", input_scale) + if not hasattr(layer, "azp_adj"): + layer.register_parameter("azp_adj", None) # Checkpoints are serialized in compressed-tensors format, which is # different from the format the kernel may want. Handle repacking here. 
diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 6ba18e59e4d54..45d2e4e338190 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -18,17 +18,19 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, is_layer_skipped, + kFp8DynamicTokenSym, + kFp8StaticTokenSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, - maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, ) from vllm.model_executor.parameter import ( @@ -91,10 +93,13 @@ class FBGEMMFp8Config(QuantizationConfig): class FBGEMMFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: FBGEMMFp8Config): self.quant_config = quant_config - self.fp8_linear = Fp8LinearOp( - act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN - ) self.out_dtype = torch.get_default_dtype() + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=kFp8DynamicTokenSym, + weight_quant_key=kFp8StaticTokenSym, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, + ) def create_weights( self, @@ -106,7 +111,6 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): params_dtype: torch.dtype, **extra_weight_attrs, ): - maybe_create_device_identity() weight_loader = extra_weight_attrs.get("weight_loader") del input_size, output_size output_size_per_partition = sum(output_partition_sizes) @@ -184,12 +188,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): bias=bias, ) - return self.fp8_linear.apply( - input=x, - 
weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=None, - input_scale_ub=layer.input_scale_ub, - bias=bias, - ) + return self.fp8_linear.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 9da19c082dc27..530844a15614d 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -45,6 +45,9 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, @@ -78,13 +81,14 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, is_layer_skipped, + kFp8DynamicTensorSym, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported, cutlass_fp8_supported, - maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize, ) @@ -431,8 +435,13 @@ class Fp8LinearMethod(LinearMethodBase): # Use per-token quantization for better perf if dynamic and cutlass if not self.act_q_static and cutlass_fp8_supported(): self.act_q_group_shape = GroupShape.PER_TOKEN + self.activation_quant_key = kFp8DynamicTokenSym + elif self.act_q_static: + self.act_q_group_shape = GroupShape.PER_TENSOR + self.activation_quant_key = kFp8StaticTensorSym else: self.act_q_group_shape = GroupShape.PER_TENSOR + self.activation_quant_key = kFp8DynamicTensorSym if self.block_quant: assert not self.act_q_static @@ -444,9 +453,11 @@ class 
Fp8LinearMethod(LinearMethodBase): use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) else: - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.act_q_static, - act_quant_group_shape=self.act_q_group_shape, + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=self.activation_quant_key, + weight_quant_key=kFp8StaticTensorSym, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, ) def create_weights( @@ -459,8 +470,6 @@ class Fp8LinearMethod(LinearMethodBase): params_dtype: torch.dtype, **extra_weight_attrs, ): - maybe_create_device_identity() - output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") layer.logical_widths = output_partition_sizes @@ -525,6 +534,7 @@ class Fp8LinearMethod(LinearMethodBase): weight_loader=patched_weight_loader, ) layer.register_parameter("weight", weight) + layer.input_scale_ub = None # If checkpoint is serialized fp8, load them. # Otherwise, wait until process_weights_after_loading. 
@@ -699,14 +709,7 @@ class Fp8LinearMethod(LinearMethodBase): bias=bias, ) - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=layer.input_scale, - bias=bias, - ) + return self.fp8_linear.apply_weights(layer, x, bias) class Fp8MoEMethod(FusedMoEMethodBase): diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index 7be220f7a3734..e82f39c4b895b 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -2,48 +2,73 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence from dataclasses import dataclass +from typing import Generic, TypeVar import torch +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) +from vllm.platforms import current_platform + @dataclass class ScaledMMLinearLayerConfig: - is_channelwise: bool + pass + + +@dataclass +class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): + # TODO: Change to QuantKey like FP8ScaledMMLinearLayerConfig is_static_input_scheme: bool + is_channelwise: bool input_symmetric: bool -class ScaledMMLinearKernel(ABC): +@dataclass +class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): + weight_quant_key: QuantKey + activation_quant_key: QuantKey + out_dtype: torch.dtype | None + + +_FP8ParamsT = tuple[ + torch.Tensor, # weight + torch.Tensor, # weight_scale + torch.Tensor | None, # input_scale, + torch.Tensor | None, # input_scale_ub, +] +_Int8ParamsT = tuple[ + torch.Tensor, # weight + torch.Tensor, # weight_scale + torch.Tensor | None, # input_scale, + 
torch.Tensor | None, # input_zp + torch.Tensor | None, # azp_adj +] + +_ParamsT = TypeVar("_ParamsT", _Int8ParamsT, _FP8ParamsT) +_ConfigT = TypeVar("_ConfigT", bound=ScaledMMLinearLayerConfig) + + +class ScaledMMLinearKernel(Generic[_ConfigT, _ParamsT], ABC): @classmethod @abstractmethod - def is_supported( - cls, compute_capability: int | None = None - ) -> tuple[bool, str | None]: + def is_platform_supported(cls) -> tuple[bool, str | None]: raise NotImplementedError @classmethod @abstractmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: _ConfigT) -> tuple[bool, str | None]: raise NotImplementedError - def __init__( - self, - c: ScaledMMLinearLayerConfig, - w_q_param_name: str, - w_s_param_name: str, - i_s_param_name: str, - i_zp_param_name: str, - azp_adj_param_name: str, - ) -> None: + def __init__(self, c: _ConfigT, layer_param_names: Sequence[str]) -> None: assert self.can_implement(c) - assert self.is_supported() + assert self.is_platform_supported() self.config = c - self.w_q_name = w_q_param_name - self.w_s_name = w_s_param_name - self.i_s_name = i_s_param_name - self.i_zp_name = i_zp_param_name - self.azp_adj_name = azp_adj_param_name + self.layer_param_names = layer_param_names @abstractmethod def process_weights_after_loading(self, layer: torch.nn.Module) -> None: @@ -58,19 +83,109 @@ class ScaledMMLinearKernel(ABC): ) -> torch.Tensor: raise NotImplementedError - def _get_weight_params( - self, layer: torch.nn.Module - ) -> tuple[ - torch.Tensor, # weight - torch.Tensor, # weight_scale - torch.Tensor | None, # input_scale, - torch.Tensor | None, # input_zp - torch.Tensor | None, # azp_adj - ]: - return ( - getattr(layer, self.w_q_name), - getattr(layer, self.w_s_name), - getattr(layer, self.i_s_name), - getattr(layer, self.i_zp_name), - getattr(layer, self.azp_adj_name), + # return a covariant type in the subclass + @abstractmethod + def _get_layer_params(self, layer) -> _ParamsT: + 
raise NotImplementedError + + @classmethod + def get_min_capability(cls) -> int: + raise NotImplementedError + + +class FP8ScaledMMLinearKernel( + ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, _FP8ParamsT], ABC +): + def __init__( + self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str] + ) -> None: + act_scale_descriptor = c.activation_quant_key.scale + self.quant_fp8 = QuantFP8( + static=act_scale_descriptor.static, + group_shape=act_scale_descriptor.group_shape, + num_token_padding=self.get_ouput_padding(), + ) + self.fp8_dtype = current_platform.fp8_dtype() + super().__init__(c, layer_param_names) + + @classmethod + def get_min_capability(cls) -> int: + # lovelace and up + return 89 + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + pass + + def _get_layer_params(self, layer) -> _FP8ParamsT: + w, w_s, x_s, x_s_ub = self.layer_param_names + return ( + getattr(layer, w), + getattr(layer, w_s), + getattr(layer, x_s), + getattr(layer, x_s_ub), + ) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + scaled_mm_func = self.get_scaled_mm_func() + quant_fp8 = self.quant_fp8 + fp8_dtype = self.fp8_dtype + maybe_out_dtype = self.config.out_dtype + w, w_s, x_s, x_s_ub = self._get_layer_params(layer) + + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.input_scale is None and x_s computed from x. + # If static, layer.input_scale is scalar and x_s is input_scale. 
+ # View input as 2D matrix for fp8 methods + x_2d = x.view(-1, x.shape[-1]) + output_shape = [*x.shape[:-1], w.shape[1]] + out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype + + # If input not quantized + # TODO(luka) remove this path if not used anymore + x_2d_q = x_2d + if x.dtype != fp8_dtype: + x_2d_q, x_s = quant_fp8( + x_2d, + x_s, + x_s_ub, + ) + return scaled_mm_func( + A=x_2d_q, + B=w, + out_dtype=out_dtype, + As=x_s, + Bs=w_s, + bias=bias, + output_shape=output_shape, + ) + + @abstractmethod + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + raise NotImplementedError + + @abstractmethod + def get_ouput_padding(self) -> int | None: + raise NotImplementedError + + +class Int8ScaledMMLinearKernel( + ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC +): + @classmethod + def get_min_capability(cls) -> int: + return 75 + + def _get_layer_params(self, layer) -> _Int8ParamsT: + w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names + return ( + getattr(layer, w_q), + getattr(layer, w_s), + getattr(layer, i_s), + getattr(layer, i_zp), + getattr(layer, azp_adj), ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 20d050d387d49..24d85b93db2b1 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -2,7 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os +from typing import TypeVar +import torch + +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( AiterScaledMMLinearKernel, ) @@ -10,9 +14,25 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import ( CPUScaledMMLinearKernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( + 
CutlassFP8ScaledMMLinearKernel, CutlassScaledMMLinearKernel, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import ( + FlashInferScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import ( + ChannelWiseTorchScaledMMLinearKernel, + PerTensorTorchScaledMMLinearKernel, + RowWiseTorchScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import ( + ROCmScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + FP8ScaledMMLinearKernel, + FP8ScaledMMLinearLayerConfig, + Int8ScaledMMLinearKernel, + Int8ScaledMMLinearLayerConfig, ScaledMMLinearKernel, ScaledMMLinearLayerConfig, ) @@ -22,60 +42,206 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( XLAScaledMMLinearKernel, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.platforms import PlatformEnum, current_platform +logger = init_logger(__name__) + # in priority/performance order (when available) -_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { +_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]] = { PlatformEnum.CPU: [CPUScaledMMLinearKernel], PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel], PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], PlatformEnum.TPU: [XLAScaledMMLinearKernel], } +# in priority/performance order (when available) +_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = { + PlatformEnum.CUDA: [ + FlashInferScaledMMLinearKernel, + CutlassFP8ScaledMMLinearKernel, + PerTensorTorchScaledMMLinearKernel, + ChannelWiseTorchScaledMMLinearKernel, + ], + PlatformEnum.ROCM: [ + ROCmScaledMMLinearKernel, + PerTensorTorchScaledMMLinearKernel, + 
RowWiseTorchScaledMMLinearKernel, + ChannelWiseTorchScaledMMLinearKernel, + ], + PlatformEnum.CPU: [ + PerTensorTorchScaledMMLinearKernel, + ChannelWiseTorchScaledMMLinearKernel, + ], +} + +_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel) +_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig) + + +def is_supported_and_can_implement_kernel( + kernel: type[_KernelT], config: _KernelConfigT, compute_capability: int | None +) -> tuple[bool, str]: + if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","): + return False, f" {kernel.__name__} disabled by environment variable" + + platform_supported, requires_platform = kernel.is_platform_supported() + if not platform_supported: + return ( + False, + f"{kernel.__name__} is not supported as it requires {requires_platform}.", + ) + + if compute_capability is None: + _cc = current_platform.get_device_capability() + if _cc is not None: + compute_capability = _cc[0] * 10 + _cc[1] + + # If the current platform uses compute_capability, + # make sure the kernel supports the compute capability.
+ if compute_capability is not None: + kernel_min_capability = kernel.get_min_capability() + if ( + kernel_min_capability is not None + and kernel_min_capability > compute_capability + ): + return ( + False, + f"{kernel.__name__} requires capability " + f"{kernel_min_capability}, current compute capability " + f"is {compute_capability}", + ) + can_implement, failure_reason = kernel.can_implement(config) + if not can_implement: + return ( + False, + f" {kernel.__name__} cannot be implemented because: {failure_reason}", + ) + + return True, "" + def choose_scaled_mm_linear_kernel( - config: ScaledMMLinearLayerConfig, compute_capability: int | None = None -) -> type[ScaledMMLinearKernel]: + config: _KernelConfigT, + possible_kernels: dict[PlatformEnum, list[type[_KernelT]]], + compute_capability: int | None = None, + force_kernel: type[_KernelT] | None = None, +) -> type[_KernelT]: """ - Choose an ScaledMMLinearKernel that can implement the given config for the + Choose a _KernelT that can implement the given config for the given compute capability. Attempts to choose the best kernel in terms of performance. Args: - config (ScaledMMLinearLayerConfig): Description of the linear layer + config (_KernelConfigT): Description of the linear layer to be implemented. + possible_kernels (dict[PlatformEnum, list[_KernelT]]): A + dictionary of platforms and their list of possible kernels. compute_capability (Optional[int], optional): The compute capability of the target device, if None uses `current_platform` to get the compute capability. Defaults to None. + force_kernel (Optional[type[_KernelT]]): An optional forced kernel to override + the possible_kernels if it can be implemented. If None, it will only try the + possible kernels. Raises: ValueError: If no kernel can implement the given config. Returns: - type[ScaledMMLinearKernel]: Chosen kernel.
""" - failure_reasons = [] - for kernel in _POSSIBLE_KERNELS[current_platform._enum]: - if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","): - failure_reasons.append(f"{kernel.__name__}: disabled by env var") - continue + failure_reason_list = [] - # If the current platform uses compute_capability, - # make sure the kernel supports the compute capability. - is_supported, reason = kernel.is_supported(compute_capability) - if not is_supported: - failure_reasons.append(f"{kernel.__name__}: {reason}") - continue + if force_kernel is not None: + can_implement, failure_reason = is_supported_and_can_implement_kernel( + force_kernel, config, compute_capability + ) + if can_implement: + return force_kernel - can_implement, reason = kernel.can_implement(config) - if not can_implement: - failure_reasons.append(f"{kernel.__name__}: {reason}") - continue + logger.info_once( + "Tried to force %s, but the kernel couldn't be implemented", + force_kernel.__name__, + scope="global", + ) - return kernel + for kernel in possible_kernels[current_platform._enum]: + is_supported_and_can_implement, failure_reason = ( + is_supported_and_can_implement_kernel(kernel, config, compute_capability) + ) + if is_supported_and_can_implement: + return kernel + failure_reason_list.append(failure_reason) raise ValueError( "Failed to find a kernel that can implement the " - "ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reasons) + "ScaledMM linear layer. 
Reasons: \n" + "\n".join(failure_reason_list) + ) + + +def init_fp8_linear_kernel( + activation_quant_key: QuantKey, + weight_quant_key: QuantKey, + out_dtype: torch.dtype, + force_kernel: type[FP8ScaledMMLinearKernel] | None = None, + module_name: str | None = None, +) -> FP8ScaledMMLinearKernel: + scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( + weight_quant_key=weight_quant_key, + activation_quant_key=activation_quant_key, + out_dtype=out_dtype, + ) + + kernel_type = choose_scaled_mm_linear_kernel( + scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, force_kernel=force_kernel + ) + + if module_name: + logger.info_once( + "Selected %s for %s", + kernel_type.__name__, + module_name, + scope="global", + ) + + return kernel_type( + scaled_mm_linear_kernel_config, + layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"], + ) + + +def init_int8_linear_kernel( + is_channelwise: bool, + is_static_input_scheme: bool, + input_symmetric: bool, + module_name: str, +) -> Int8ScaledMMLinearKernel: + config = Int8ScaledMMLinearLayerConfig( + is_channelwise=is_channelwise, + is_static_input_scheme=is_static_input_scheme, + input_symmetric=input_symmetric, + ) + + kernel_type = choose_scaled_mm_linear_kernel( + config, + _POSSIBLE_INT8_KERNELS, + ) + + logger.info_once( + "Selected %s for %s", + kernel_type.__name__, + module_name, + scope="global", + ) + + return kernel_type( + config, + layer_param_names=[ + "weight", + "weight_scale", + "input_scale", + "input_zero_point", + "azp_adj", + ], ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index 971bd2005a23b..e3f94eaa7e847 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -9,27 +9,22 @@ from vllm._aiter_ops import rocm_aiter_ops from vllm.platforms import
current_platform from .cutlass import CutlassScaledMMLinearKernel -from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): @classmethod - def is_supported( - cls, compute_capability: int | None = None - ) -> tuple[bool, str | None]: - if not current_platform.is_rocm(): - return ( - False, - "AiterScaledMMLinearKernel requires `aiter` which is not " - + "currently supported on non-ROCm platform.", - ) - if compute_capability is None: - _cc = current_platform.get_device_capability() - if _cc is not None: - compute_capability = _cc.major * 10 + _cc.minor - if compute_capability is not None and compute_capability < 90: - return False, f"requires capability 90, got {compute_capability}" + def get_min_capability(cls) -> int: + return 90 + @classmethod + def is_platform_supported(cls) -> tuple[bool, str | None]: + if not current_platform.is_rocm(): + return False, "ROCm" + return True, None + + @classmethod + def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: try: import aiter # noqa: F401 # deliberately attempt to import aiter except Exception: @@ -48,10 +43,6 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): + "`VLLM_ROCM_USE_AITER_LINEAR` default is True.", ) - return True, None - - @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: if not c.input_symmetric: return ( False, @@ -59,9 +50,6 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): ) return True, None - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - super().process_weights_after_loading(layer) - def apply_weights( self, layer: torch.nn.Module, @@ -78,7 +66,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support ATIER block scaled GEMM and mix-precision GEMM. 
""" - w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer) # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py index 6401b94d6278b..a9f1c5b2c9a77 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py @@ -14,24 +14,34 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import ( + Int8ScaledMMLinearKernel, + Int8ScaledMMLinearLayerConfig, +) -class CPUScaledMMLinearKernel(ScaledMMLinearKernel): +class CPUScaledMMLinearKernel(Int8ScaledMMLinearKernel): @classmethod - def is_supported( - cls, compute_capability: int | None = None + def get_min_capability(cls) -> int: + # current_platform.get_device_capability() returns None + # so the check will be ignored + return -1 + + @classmethod + def is_platform_supported( + cls, ) -> tuple[bool, str | None]: if not current_platform.is_cpu(): - return False, "Requires CPU." 
+ return False, "CPU" return True, None @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - weight = getattr(layer, self.w_q_name) + w_q_name, _, _, _, _ = self.layer_param_names + weight = getattr(layer, w_q_name) dtype = weight.dtype N, K = weight.size() if ( @@ -49,10 +59,11 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: # WEIGHT # Transpose to [K, N] for convenience - weight = getattr(layer, self.w_q_name) + w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names + weight = getattr(layer, w_q_name) replace_parameter( layer, - self.w_q_name, + w_q_name, torch.nn.Parameter(weight.t().data, requires_grad=False), ) @@ -61,28 +72,27 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): # If we have a fused module (QKV, MLP) with per tensor scales (thus N # scales being passed to the kernel), convert to the per-channel case. 
is_fused_module = len(layer.logical_widths) > 1 - weight_scale = getattr(layer, self.w_s_name) + weight_scale = getattr(layer, w_s_name) if is_fused_module and not self.config.is_channelwise: weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) replace_parameter( layer, - self.w_s_name, + w_s_name, torch.nn.Parameter(weight_scale.data, requires_grad=False), ) # INPUT SCALE if self.config.is_static_input_scheme: - input_scale = getattr(layer, self.i_s_name) + input_scale = getattr(layer, i_s_name) if self.config.input_symmetric: replace_parameter( layer, - self.i_s_name, + i_s_name, torch.nn.Parameter(input_scale.max(), requires_grad=False), ) - setattr(layer, self.i_zp_name, None) else: - input_zero_point = getattr(layer, self.i_zp_name) + input_zero_point = getattr(layer, i_zp_name) # reconstruct the ranges int8_traits = torch.iinfo(torch.int8) @@ -92,20 +102,16 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) replace_parameter( - layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False) + layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False) ) azp = ( (int8_traits.min - range_min / scale).round().to(dtype=torch.int32) ) replace_parameter( - layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False) + layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False) ) - else: - setattr(layer, self.i_s_name, None) - setattr(layer, self.i_zp_name, None) - # Different from cutlass, oneDNN kernels only need the AZP adjustment # term for dynamic quantization. And s_b should be folded into the # term. 
Such as: @@ -113,38 +119,37 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): # s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias = # s_a * GEMM_output - s_a * zp_a * adj + bias if not (self.config.input_symmetric and self.config.is_static_input_scheme): - weight = getattr(layer, self.w_q_name) - weight_scale = getattr(layer, self.w_s_name) + weight = getattr(layer, w_q_name) + weight_scale = getattr(layer, w_s_name) azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32) azp_adj = azp_adj * weight_scale.squeeze() setattr( layer, - self.azp_adj_name, + azp_adj_name, torch.nn.Parameter(azp_adj, requires_grad=False), ) - else: - setattr(layer, self.azp_adj_name, None) - weight = getattr(layer, self.w_q_name) + weight = getattr(layer, w_q_name) self.dnnl_handler = ops.create_onednn_scaled_mm( weight, - getattr(layer, self.w_s_name), + getattr(layer, w_s_name), torch.get_default_dtype(), - getattr(layer, self.i_s_name) is None, + getattr(layer, i_s_name) is None, not self.config.input_symmetric, 32, ) # weight is prepacked and maintained by the dnnl_handler, # release the original weight - setattr(layer, self.w_q_name, None) + setattr(layer, w_q_name, None) del weight def process_weights_for_sgl(self, layer: torch.nn.Module) -> None: + w_q_name, w_s_name, _, _, _ = self.layer_param_names # WEIGHT - weight = getattr(layer, self.w_q_name) + weight = getattr(layer, w_q_name) packed_weight = torch.ops._C.convert_weight_packed(weight) replace_parameter( - layer, self.w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False) + layer, w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False) ) if layer.bias is not None: @@ -156,19 +161,15 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): # WEIGHT SCALE # CPU SGL kernels only support per-channel. # For per-tensor quant, convert to the per-channel case. 
- weight_scale = getattr(layer, self.w_s_name) + weight_scale = getattr(layer, w_s_name) if not self.config.is_channelwise: weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) replace_parameter( layer, - self.w_s_name, + w_s_name, torch.nn.Parameter(weight_scale.data, requires_grad=False), ) - setattr(layer, self.i_s_name, None) - setattr(layer, self.i_zp_name, None) - setattr(layer, self.azp_adj_name, None) - def apply_weights( self, layer: torch.nn.Module, @@ -187,7 +188,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer) # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. @@ -209,7 +210,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - w_q, w_s, _, _, _ = self._get_weight_params(layer) + w_q, w_s, _, _, _ = self._get_layer_params(layer) return torch.ops._C.int8_scaled_mm_with_quant( x, w_q, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 2f00e0df8ed47..3ea5798be1735 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + import torch from vllm import _custom_ops as ops @@ -11,35 +13,51 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ) from vllm.platforms import current_platform -from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import ( + FP8ScaledMMLinearKernel, + 
FP8ScaledMMLinearLayerConfig, + Int8ScaledMMLinearKernel, + Int8ScaledMMLinearLayerConfig, +) -class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): +def cutlass_w8a8_scaled_mm_fp8( + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor, + output_shape: list, +) -> torch.Tensor: + # Fused GEMM_DQ + output = ops.cutlass_scaled_mm( + A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias + ) + return output.view(*output_shape) + + +class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel): @classmethod - def is_supported( - cls, compute_capability: int | None = None - ) -> tuple[bool, str | None]: + def is_platform_supported(cls) -> tuple[bool, str | None]: if not current_platform.is_cuda(): - return False, "Requires CUDA." - if compute_capability is None: - _cc = current_platform.get_device_capability() - if _cc is not None: - compute_capability = _cc.major * 10 + _cc.minor - if compute_capability is not None and compute_capability < 75: - return False, f"requires capability 75, got {compute_capability}" + return False, "CUDA" return True, None @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names + config = self.config # WEIGHT # Cutlass kernels need transposed weight. - weight = getattr(layer, self.w_q_name) + weight = getattr(layer, w_q_name) replace_parameter( layer, - self.w_q_name, + w_q_name, torch.nn.Parameter(weight.t().data, requires_grad=False), ) @@ -48,28 +66,28 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): # If we have a fused module (QKV, MLP) with per tensor scales (thus N # scales being passed to the kernel), convert to the per-channel case. 
is_fused_module = len(layer.logical_widths) > 1 - weight_scale = getattr(layer, self.w_s_name) - if is_fused_module and not self.config.is_channelwise: + weight_scale = getattr(layer, w_s_name) + if is_fused_module and not config.is_channelwise: weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) replace_parameter( layer, - self.w_s_name, + w_s_name, torch.nn.Parameter(weight_scale.data, requires_grad=False), ) # INPUT SCALE - if self.config.is_static_input_scheme: - input_scale = getattr(layer, self.i_s_name) + if config.is_static_input_scheme: + input_scale = getattr(layer, i_s_name) - if self.config.input_symmetric: + if config.input_symmetric: replace_parameter( layer, - self.i_s_name, + i_s_name, torch.nn.Parameter(input_scale.max(), requires_grad=False), ) - setattr(layer, self.i_zp_name, None) + setattr(layer, i_zp_name, None) else: - input_zero_point = getattr(layer, self.i_zp_name) + input_zero_point = getattr(layer, i_zp_name) # reconstruct the ranges int8_traits = torch.iinfo(torch.int8) @@ -79,38 +97,32 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) replace_parameter( - layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False) + layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False) ) # AZP loaded as int8 but used as int32 azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32) replace_parameter( - layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False) + layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False) ) - else: - setattr(layer, self.i_s_name, None) - setattr(layer, self.i_zp_name, None) - # azp_adj is the AZP adjustment term, used to account for weights. # It does not depend on scales or azp, so it is the same for # static and dynamic quantization. 
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md # https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md - if not self.config.input_symmetric: - weight = getattr(layer, self.w_q_name) + if not config.input_symmetric: + weight = getattr(layer, w_q_name) azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32) - if self.config.is_static_input_scheme: + if config.is_static_input_scheme: # cutlass_w8a8 requires azp to be folded into azp_adj # in the per-tensor case - azp_adj = getattr(layer, self.i_zp_name) * azp_adj + azp_adj = getattr(layer, i_zp_name) * azp_adj setattr( layer, - self.azp_adj_name, + azp_adj_name, torch.nn.Parameter(azp_adj, requires_grad=False), ) - else: - setattr(layer, self.azp_adj_name, None) def apply_weights( self, @@ -118,7 +130,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer) # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. 
@@ -145,3 +157,21 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): return ops.cutlass_scaled_mm( x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias ) + + +class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): + @classmethod + def is_platform_supported(cls) -> tuple[bool, str | None]: + if not current_platform.is_cuda(): + return False, "CUDA" + return True, None + + @classmethod + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + return True, None + + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return cutlass_w8a8_scaled_mm_fp8 + + def get_ouput_padding(self) -> int | None: + return None diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flashinfer.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flashinfer.py new file mode 100644 index 0000000000000..f595c194ad590 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flashinfer.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable + +import torch + +from vllm.platforms import current_platform +from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer + +from .ScaledMMLinearKernel import ( + FP8ScaledMMLinearKernel, + FP8ScaledMMLinearLayerConfig, +) + + +def flashinfer_w8a8_scaled_mm( + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor, + output_shape: list, +) -> torch.Tensor: + return flashinfer_scaled_fp8_mm( + A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias + ) + + +class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel): + @classmethod + def is_platform_supported(cls) -> tuple[bool, str | None]: + if not current_platform.is_cuda(): + return False, "CUDA" + return True, None + + @classmethod + def can_implement(cls, c: 
FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + per_tensor_activation_scales = ( + c.activation_quant_key.scale.group_shape.is_per_tensor() + ) + per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() + + if not has_flashinfer(): + return ( + False, + "FlashInferScaledMMLinearKernel requires " + + "FlashInfer to be installed.", + ) + + if not (per_tensor_activation_scales and per_tensor_weight_scales): + return ( + False, + "FlashInferScaledMMLinearKernel requires " + + "per tensor activation and weight scales.", + ) + return True, None + + @classmethod + def get_min_capability(cls) -> int: + return 100 + + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return flashinfer_w8a8_scaled_mm + + def get_ouput_padding(self) -> int | None: + return None diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py new file mode 100644 index 0000000000000..f8a06103abaf3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py @@ -0,0 +1,238 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable + +import torch +from packaging import version + +from vllm.config import CompilationMode, get_current_vllm_config +from vllm.platforms import current_platform + +from .ScaledMMLinearKernel import ( + FP8ScaledMMLinearKernel, + FP8ScaledMMLinearLayerConfig, +) + + +def torch_per_tensor_w8a8_scaled_mm( + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor, + output_shape: list, +) -> torch.Tensor: + output = torch._scaled_mm( + A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias + ) + # A fix for
discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + + return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape) + + +def torch_row_wise_w8a8_scaled_mm( + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor, + output_shape: list, +) -> torch.Tensor: + # Note: + # For now it has only been validated on ROCm platform. + # fp8 rowwise scaling in torch._scaled_mm is introduced in + # https://github.com/pytorch/pytorch/pull/144432 using + # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above. + # + # For CUDA platform please validate if the torch._scaled_mm supports + # rowwise scaled GEMM before using it + + # Fused GEMM_DQ Rowwise GEMM + output = torch._scaled_mm( + A, + B, + out_dtype=out_dtype, + scale_a=As, + scale_b=Bs.t(), + bias=bias, + ) + + output = torch.narrow(output, 0, 0, output_shape[0]) + output = output.view(*output_shape) + return output + + +def torch_channelwise_w8a8_scaled_mm( + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor, + output_shape: list, +) -> torch.Tensor: + # Use unfused DQ due to limitations with scaled_mm + + # Symmetric quantized GEMM by definition computes the following: + # C = (s_x * X) (s_w * W) + bias + # This is equivalent to dequantizing the weights and activations + # before applying a GEMM. + # + # In order to compute quantized operands, a quantized kernel + # will rewrite the above like so: + # C = s_w * s_x * (X * W) + bias + # + # For the scaled_mm fallback case, we break this down, since it + # does not support s_w being a vector. + + # Input scaling factors are no longer optional in _scaled_mm starting + # from pytorch 2.5. 
Allocating a dummy tensor to pass as scales + dummy_tensor = torch.ones(1, dtype=torch.float32, device=A.device) + + # GEMM + # This computes C = (X * W). + # Output in fp32 to allow subsequent ops to happen in-place + output = torch._scaled_mm( + A, + B, + scale_a=dummy_tensor, + scale_b=dummy_tensor, + out_dtype=torch.float32, + ) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + # Unpad (undo num_token_padding) + output = torch.narrow(output, 0, 0, output_shape[0]) + x_scale = torch.narrow(As, 0, 0, output_shape[0]) + + # DQ + # C = sw * sx * (X * W) + bias + output = output * x_scale * Bs.t() + if bias is not None: + output = output + bias + return output.to(out_dtype).view(*output_shape) + + +class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel): + """ + Base class for FP8 linear kernels using Torch. + Each subclass represents a kernel variant for + specific device capabilities and torch versions, + so we split them up and implement + get_min_capability() separately for each. + """ + + @classmethod + def is_platform_supported( + cls, + ) -> tuple[bool, str | None]: + if not current_platform.is_cuda_alike(): + return False, "ROCm or CUDA" + return True, None + + def get_ouput_padding(self) -> int | None: + # Note: we pad the input because torch._scaled_mm is more performant + # for matrices with batch dimension > 16. + # This could change in the future. + # We also don't pad when using torch.compile, + # as it breaks with dynamic shapes. 
+ vllm_config = get_current_vllm_config().compilation_config + pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE + output_padding = 17 if pad_output else None + return output_padding + + +class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): + @classmethod + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + per_tensor_activation_scales = ( + c.activation_quant_key.scale.group_shape.is_per_tensor() + ) + per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() + + if not (per_tensor_activation_scales and per_tensor_weight_scales): + return ( + False, + "PerTensorTorchScaledMMLinearKernel requires " + + "per tensor activation and weight scales.", + ) + return True, None + + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return torch_per_tensor_w8a8_scaled_mm + + +class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): + @classmethod + def get_min_capability(cls) -> int: + return 94 + + @classmethod + def is_platform_supported(cls) -> tuple[bool, str | None]: + if not current_platform.is_rocm(): + return False, "ROCm" + return True, None + + @classmethod + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + per_tensor_activation_scales = ( + c.activation_quant_key.scale.group_shape.is_per_tensor() + ) + per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() + + if c.out_dtype == torch.float16: + # hipblaslt rowwise _scaled_mm only supports BFloat16 + return ( + False, + "RowWiseTorchScaledMMLinearKernel only supports BFloat16.", + ) + + if per_tensor_activation_scales or per_tensor_weight_scales: + return ( + False, + "RowWiseTorchScaledMMLinearKernel cannot be used with " + + "per tensor activation and weight scales.", + ) + + if not version.parse(torch.__version__) >= version.parse("2.7"): + return ( + False, + "RowWiseTorchScaledMMLinearKernel requires " + "pytorch version >=2.7.", + ) + + return 
True, None + + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return torch_row_wise_w8a8_scaled_mm + + +class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): + @classmethod + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + per_tensor_activation_scales = ( + c.activation_quant_key.scale.group_shape.is_per_tensor() + ) + per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() + + if per_tensor_activation_scales and per_tensor_weight_scales: + return ( + False, + "ChannelWiseTorchScaledMMLinearKernel cannot be used with " + + "per tensor activation and weight scales.", + ) + + return True, None + + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return torch_channelwise_w8a8_scaled_mm diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py new file mode 100644 index 0000000000000..81cfede4e16fd --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.utils.platform_utils import get_cu_count +from vllm.utils.torch_utils import direct_register_custom_op + +from .ScaledMMLinearKernel import ( + FP8ScaledMMLinearKernel, + FP8ScaledMMLinearLayerConfig, +) + + +def rocm_per_tensor_float_w8a8_scaled_mm_impl( + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + if ( + A.shape[0] == 1 + and B.shape[1] % 16 == 0 + and ((bias is None) or (bias.dtype == out_dtype)) + ): + output = ops.wvSplitKQ( + B.t(), + A, + out_dtype, + As, + Bs, + get_cu_count(), + bias, + ) 
+ # Fallback + else: + output = torch._scaled_mm( + A, + B, + out_dtype=out_dtype, + scale_a=As, + scale_b=Bs, + bias=bias, + ) + return output + + +def rocm_per_tensor_float_w8a8_scaled_mm_fake( + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + return A.new_empty((*A.shape[:-1], B.shape[1]), dtype=out_dtype) + + +def rocm_per_tensor_float_w8a8_scaled_mm( + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor, + output_shape: list[int], +) -> torch.Tensor: + output = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl( + A, B, out_dtype, As, Bs, bias + ) + return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape) + + +if current_platform.is_rocm(): + direct_register_custom_op( + op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl", + op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl, + fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake, + ) + + +class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel): + @classmethod + def is_platform_supported(cls) -> tuple[bool, str | None]: + if not current_platform.is_rocm(): + return False, "ROCm" + + from vllm.platforms.rocm import on_mi3xx + + if not on_mi3xx(): + return False, "ROCm MI3xx" + + return True, None + + @classmethod + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + per_tensor_activation_scales = ( + c.activation_quant_key.scale.group_shape.is_per_tensor() + ) + per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() + + if not envs.VLLM_ROCM_USE_SKINNY_GEMM: + return ( + False, + "VLLM_ROCM_USE_SKINNY_GEMM must be enabled " + + "to use ROCmScaledMMLinearKernel.", + ) + + if not (per_tensor_activation_scales and per_tensor_weight_scales): + return ( + False, + "ROCmScaledMMLinearKernel requires " + + "per tensor activation and weight scales.", + ) + return True, None + + 
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return rocm_per_tensor_float_w8a8_scaled_mm + + def get_ouput_padding(self) -> int | None: + return None diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py index 760f1f7f79576..1eb36e41b1ea8 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py @@ -11,46 +11,49 @@ from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.platforms import current_platform -from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig +from .cutlass import CutlassScaledMMLinearKernel +from .ScaledMMLinearKernel import ( + Int8ScaledMMLinearLayerConfig, +) -class TritonScaledMMLinearKernel(ScaledMMLinearKernel): +class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel): @classmethod - def is_supported( - cls, compute_capability: int | None = None - ) -> tuple[bool, str | None]: + def is_platform_supported(cls) -> tuple[bool, str | None]: if current_platform.is_cuda_alike(): return True, None - return False, "Requires ROCm or CUDA." + return False, "ROCm or CUDA" @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: if not c.input_symmetric: return False, "Only symmetric input is supported." 
return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - weight = getattr(layer, self.w_q_name) + w_q, _, i_s, _, _ = self._get_layer_params(layer) + w_q_name, _, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names + replace_parameter( layer, - self.w_q_name, - torch.nn.Parameter(weight.t().data, requires_grad=False), + w_q_name, + torch.nn.Parameter(w_q.t().data, requires_grad=False), ) # INPUT SCALE if self.config.is_static_input_scheme: - input_scale = getattr(layer, self.i_s_name) + assert i_s is not None replace_parameter( layer, - self.i_s_name, - torch.nn.Parameter(input_scale.max(), requires_grad=False), + i_s_name, + torch.nn.Parameter(i_s.max(), requires_grad=False), ) - setattr(layer, self.i_zp_name, None) + setattr(layer, i_zp_name, None) else: - setattr(layer, self.i_s_name, None) - setattr(layer, self.i_zp_name, None) + setattr(layer, i_s_name, None) + setattr(layer, i_zp_name, None) - setattr(layer, self.azp_adj_name, None) + setattr(layer, azp_adj_name, None) def apply_weights( self, @@ -58,7 +61,7 @@ class TritonScaledMMLinearKernel(ScaledMMLinearKernel): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + w_q, w_s, i_s, i_zp, _ = self._get_layer_params(layer) x_q, x_s, x_zp = ops.scaled_int8_quant( x.contiguous(), i_s, i_zp, symmetric=True diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index 0be858c51993d..4a44c7640d16c 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -12,23 +12,21 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ) from vllm.platforms import current_platform -from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import ( + 
Int8ScaledMMLinearKernel, + Int8ScaledMMLinearLayerConfig, +) -class XLAScaledMMLinearKernel(ScaledMMLinearKernel): +class XLAScaledMMLinearKernel(Int8ScaledMMLinearKernel): @classmethod - def is_supported( - cls, compute_capability: int | None = None - ) -> tuple[bool, str | None]: + def is_platform_supported(cls) -> tuple[bool, str | None]: if not current_platform.is_tpu(): - return False, "Requires TPU." + return False, "TPU" return True, None @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - if not current_platform.is_tpu(): - return False, "ScaledMMXLA requires running on TPU." - + def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: if c.is_static_input_scheme: return False, "ScaledMMXLA requires dynamic activation scales." @@ -43,9 +41,10 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # WEIGHT # [out, in] (different than cutlass_scaled_mm) - weight = getattr(layer, self.w_q_name) + w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names + weight = getattr(layer, w_q_name) replace_parameter( - layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False) + layer, w_q_name, torch.nn.Parameter(weight.data, requires_grad=False) ) # WEIGHT SCALE @@ -53,7 +52,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel): # If we have a fused module (QKV, MLP) with per tensor scales (thus N # scales being passed to the kernel), convert to the per-channel case. 
is_fused_module = len(layer.logical_widths) > 1 - weight_scale = getattr(layer, self.w_s_name) + weight_scale = getattr(layer, w_s_name) if is_fused_module and not self.config.is_channelwise: weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) @@ -61,14 +60,14 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel): weight_scale = weight_scale.squeeze(-1) replace_parameter( layer, - self.w_s_name, + w_s_name, torch.nn.Parameter(weight_scale.data, requires_grad=False), ) # Only support symmetric dynamic activation quantization. - setattr(layer, self.i_s_name, None) - setattr(layer, self.i_zp_name, None) - setattr(layer, self.azp_adj_name, None) + setattr(layer, i_s_name, None) + setattr(layer, i_zp_name, None) + setattr(layer, azp_adj_name, None) # Filter warning for cond usage in apply_weights. It is okay # to specialize the graph since bias is not dynamic. @@ -89,7 +88,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - w_q, w_s, _, _, _ = self._get_weight_params(layer) + w_q, w_s, _, _, _ = self._get_layer_params(layer) # Required to register custom ops. 
import torch_xla.experimental.custom_kernel # noqa: F401 diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index afbefe1fedc18..edb7dfd94b0e6 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -34,6 +34,9 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( build_flashinfer_fp4_cutlass_moe_prepare_finalize, @@ -71,10 +74,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, cutlass_fp4_supported, is_layer_skipped, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, + kFp8StaticTokenSym, swizzle_blockscale, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, cutlass_block_fp8_supported, requantize_with_max_scale, ) @@ -431,8 +436,11 @@ class ModelOptFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptFp8Config) -> None: self.quant_config = quant_config - self.fp8_linear = Fp8LinearOp( - act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=kFp8StaticTensorSym, + weight_quant_key=kFp8StaticTensorSym, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, ) def create_weights( @@ -500,13 +508,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=layer.input_scale, - bias=bias, - ) + return self.fp8_linear.apply_weights(layer, x, bias) class 
ModelOptFp8PcPtLinearMethod(LinearMethodBase): @@ -520,8 +522,11 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptFp8Config) -> None: self.quant_config = quant_config - self.fp8_linear = Fp8LinearOp( - act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=kFp8DynamicTokenSym, + weight_quant_key=kFp8StaticTokenSym, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, ) def create_weights( @@ -578,13 +583,7 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=None, - bias=bias, - ) + return self.fp8_linear.apply_weights(layer, x, bias) class ModelOptFp8PbWoLinearMethod(LinearMethodBase): diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index ed8a2c7fa0841..50b098068906d 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -17,11 +17,13 @@ from vllm.model_executor.layers.quantization.fp8 import ( Fp8KVCacheMethod, Fp8LinearMethod, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, - is_layer_skipped, +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped, + kFp8DynamicTokenSym, ) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -97,9 +99,11 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): ) super().__init__(quant_config=quant_config) # Force weight quantization - 
self.quant_config.is_checkpoint_fp8_serialized = False - self.fp8_linear = Fp8LinearOp( - act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=kFp8DynamicTokenSym, + weight_quant_key=kFp8DynamicTokenSym, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: @@ -126,11 +130,4 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=None, - input_scale_ub=None, - bias=bias, - ) + return self.fp8_linear.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 1e5ee93b61f2b..819348c5b938e 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -7,10 +7,18 @@ from typing import Any, cast import torch from torch.nn import Parameter +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, + kFp8StaticTokenSym, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale, ) @@ -23,6 +31,8 @@ from vllm.platforms import current_platform __all__ = ["QuarkW8A8Fp8"] +logger = init_logger(__name__) + class QuarkW8A8Fp8(QuarkScheme): def 
__init__( @@ -35,15 +45,16 @@ class QuarkW8A8Fp8(QuarkScheme): self.is_static_input_scheme = not cast(bool, input_config.get("is_dynamic")) self.input_qscheme = cast(str, input_config.get("qscheme")) - per_token = ( + per_token_activation = ( not self.is_static_input_scheme and self.input_qscheme == "per_channel" ) - self.act_quant_group_shape = ( - GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR + per_token_weight = self.weight_qscheme == "per_channel" + + self.activation_quant_key = ( + kFp8DynamicTokenSym if per_token_activation else kFp8StaticTensorSym ) - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.is_static_input_scheme, - act_quant_group_shape=self.act_quant_group_shape, + self.weight_quant_key = ( + kFp8StaticTokenSym if per_token_weight else kFp8StaticTensorSym ) self.out_dtype = torch.get_default_dtype() @@ -94,7 +105,7 @@ class QuarkW8A8Fp8(QuarkScheme): layer.input_scale = Parameter(input_scale, requires_grad=False) else: weight_scale = layer.weight_scale.data - if self.act_quant_group_shape == GroupShape.PER_TOKEN: + if self.activation_quant_key.scale.group_shape == GroupShape.PER_TOKEN: weight_scale = weight_scale.view(-1, 1) layer.weight = Parameter(weight.t(), requires_grad=False) # required by torch.compile to be torch.nn.Parameter @@ -163,17 +174,19 @@ class QuarkW8A8Fp8(QuarkScheme): input_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("input_scale", input_scale) + layer.input_scale_ub = None + + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=self.activation_quant_key, + weight_quant_key=self.weight_quant_key, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, + ) + def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=layer.input_scale, - bias=bias, - ) 
+ return self.fp8_linear.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py index 42d2ed2e85ed9..a7a7726bae0e2 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py @@ -7,8 +7,7 @@ import torch from vllm.logger import init_logger from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - ScaledMMLinearLayerConfig, - choose_scaled_mm_linear_kernel, + init_int8_linear_kernel, ) from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.parameter import ( @@ -22,8 +21,6 @@ logger = init_logger(__name__) class QuarkW8A8Int8(QuarkScheme): - _kernel_backends_being_used: set[str] = set() - def __init__( self, qscheme: str, @@ -50,18 +47,13 @@ class QuarkW8A8Int8(QuarkScheme): ): layer.logical_widths = output_partition_sizes - scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( + self.kernel = init_int8_linear_kernel( is_channelwise=(self.qscheme == "per_channel"), is_static_input_scheme=(self.is_static_input_scheme is True), input_symmetric=(self.input_symmetric is True), + module_name=self.__class__.__name__, ) - kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config) - - if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__) - self._kernel_backends_being_used.add(kernel_type.__name__) - # WEIGHT weight = ModelWeightParameter( data=torch.empty( @@ -102,25 +94,21 @@ class QuarkW8A8Int8(QuarkScheme): layer.register_parameter("weight_zero_point", weight_zero_point) # INPUT SCALE + input_zero_point = None + input_scale = None if self.is_static_input_scheme: input_scale = BasevLLMParameter( data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader ) - 
layer.register_parameter("input_scale", input_scale) input_zero_point = BasevLLMParameter( data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader ) - layer.register_parameter("input_zero_point", input_zero_point) - self.kernel = kernel_type( - c=scaled_mm_linear_kernel_config, - w_q_param_name="weight", - w_s_param_name="weight_scale", - i_s_param_name="input_scale", - i_zp_param_name="input_zero_point", - azp_adj_param_name="azp_adj", - ) + layer.register_parameter("input_scale", input_scale) + layer.register_parameter("input_zero_point", input_zero_point) + if not hasattr(layer, "azp_adj"): + layer.register_parameter("azp_adj", None) # Checkpoints are serialized in quark format, which is # different from the format the kernel may want. Handle repacking here. diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index d01263f82007d..e7390a9322fb0 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -109,6 +109,9 @@ kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True) kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR) kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True) +kStaticTokenScale = ScaleDesc(torch.float32, True, GroupShape.PER_TOKEN) +kFp8StaticTokenSym = QuantKey(FP8_DTYPE, kStaticTokenScale, symmetric=True) + kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN) kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 4287922417c63..f949c0c076e71 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -1,34 +1,11 @@ # 
SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable import torch -from packaging import version from vllm import _custom_ops as ops -from vllm import envs -from vllm.config import CompilationMode, get_current_vllm_config -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform -from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer -from vllm.utils.platform_utils import get_cu_count -from vllm.utils.torch_utils import direct_register_custom_op - -# Input scaling factors are no longer optional in _scaled_mm starting -# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale -TORCH_DEVICE_IDENTITY = None - -# The condition to determine if it is on a platform that supports -# torch._scaled_mm rowwise feature. -# The condition is determined once as the operations -# are time-consuming. 
-USE_ROWWISE_TORCH_SCALED_MM = ( - current_platform.is_rocm() - and version.parse(torch.__version__) >= version.parse("2.7") - and current_platform.has_device_capability(94) -) def sparse_cutlass_supported() -> bool: @@ -140,361 +117,6 @@ def requantize_with_max_scale( return max_w_scale, weight -def maybe_create_device_identity(): - # Allocate dummy ones tensor for torch._scaled_mm - global TORCH_DEVICE_IDENTITY - if TORCH_DEVICE_IDENTITY is None: - TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) - - -def cutlass_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, - **kwargs, -) -> torch.Tensor: - # Fused GEMM_DQ - output = ops.cutlass_scaled_mm( - qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias - ) - return output.view(*output_shape) - - -def flashinfer_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, - **kwargs, -) -> torch.Tensor: - return flashinfer_scaled_fp8_mm( - qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias - ) - - -def rocm_per_tensor_w8a8_scaled_mm_impl( - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, -) -> torch.Tensor: - from vllm.platforms.rocm import on_mi3xx - - if ( - envs.VLLM_ROCM_USE_SKINNY_GEMM - and on_mi3xx() - and qinput.shape[0] == 1 - and qinput.shape[1] % 16 == 0 - and ((bias is None) or (bias.dtype == out_dtype)) - ): - output = ops.wvSplitKQ( - weight.t(), - qinput, - out_dtype, - scale_a, - scale_b, - get_cu_count(), - bias, - ) - else: - output = torch._scaled_mm( - qinput, - weight, - out_dtype=out_dtype, - scale_a=scale_a, - scale_b=scale_b, - bias=bias, - ) - return output - - -def 
rocm_per_tensor_w8a8_scaled_mm_fake( - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, -) -> torch.Tensor: - return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]), dtype=out_dtype) - - -def rocm_per_tensor_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, -) -> torch.Tensor: - output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl( - qinput, weight, out_dtype, scale_a, scale_b, bias - ) - return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape) - - -direct_register_custom_op( - op_name="rocm_per_tensor_w8a8_scaled_mm_impl", - op_func=rocm_per_tensor_w8a8_scaled_mm_impl, - fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake, -) - - -def torch_per_tensor_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, -) -> torch.Tensor: - output = torch._scaled_mm( - qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias - ) - # A fix for discrepancy in scaled_mm which returns tuple - # for torch < 2.5 and a single value in torch >= 2.5 - if type(output) is tuple and len(output) == 2: - output = output[0] - - return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape) - - -def torch_per_token_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, - **kwargs, -) -> torch.Tensor: - # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM - # when using it. - # For now it has only been validated on ROCm platform. 
- # fp8 rowwise scaling in torch._scaled_mm is introduced in - # https://github.com/pytorch/pytorch/pull/144432 using - # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above. - # - # For CUDA platform please validate if the torch._scaled_mm supports - # rowwise scaled GEMM before using it - - # Fused GEMM_DQ Rowwise GEMM - output = torch._scaled_mm( - qinput, - weight, - out_dtype=out_dtype, - scale_a=scale_a, - scale_b=scale_b.t(), - bias=bias, - ) - - output = torch.narrow(output, 0, 0, qinput.shape[0]) - output = output.view(*output_shape) - return output - - -def torch_channelwise_w8a8_scaled_mm( - *, - qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor, - output_shape: list, - **kwargs, -) -> torch.Tensor: - # Use unfused DQ due to limitations with scaled_mm - - # Symmetric quantized GEMM by definition computes the following: - # C = (s_x * X) (s_w * W) + bias - # This is equivalent to dequantizing the weights and activations - # before applying a GEMM. - # - # In order to compute quantized operands, a quantized kernel - # will rewrite the above like so: - # C = s_w * s_x * (X * W) + bias - # - # For the scaled_mm fallback case, we break this down, since it - # does not support s_w being a vector. - - # GEMM - # This computes C = (X * W). 
- # Output in fp32 to allow subsequent ops to happen in-place - output = torch._scaled_mm( - qinput, - weight, - scale_a=TORCH_DEVICE_IDENTITY, - scale_b=TORCH_DEVICE_IDENTITY, - out_dtype=torch.float32, - ) - # A fix for discrepancy in scaled_mm which returns tuple - # for torch < 2.5 and a single value in torch >= 2.5 - if type(output) is tuple and len(output) == 2: - output = output[0] - # Unpad (undo num_token_padding) - output = torch.narrow(output, 0, 0, qinput.shape[0]) - x_scale = torch.narrow(scale_a, 0, 0, qinput.shape[0]) - - # DQ - # C = sw * sx * (X * W) + bias - output = output * x_scale * scale_b.t() - if bias is not None: - output = output + bias - return output.to(out_dtype).view(*output_shape) - - -def dispatch_w8a8_scaled_mm( - preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool -) -> Callable[..., torch.Tensor]: - if per_tensor_weights and per_tensor_activations: - if preferred_backend == "rocm": - return rocm_per_tensor_w8a8_scaled_mm - if preferred_backend == "flashinfer": - return flashinfer_w8a8_scaled_mm - if preferred_backend == "cutlass": - return cutlass_w8a8_scaled_mm - return torch_per_tensor_w8a8_scaled_mm - - # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A - if preferred_backend == "cutlass" or preferred_backend == "flashinfer": - return cutlass_w8a8_scaled_mm - - # If torch.scaled_mm supports per-channel (weights) per-token (inputs) - if ( - not per_tensor_weights - and not per_tensor_activations - and USE_ROWWISE_TORCH_SCALED_MM - ): - return torch_per_token_w8a8_scaled_mm - # Normally, torch.scaled_mm supports per tensor weights + activations only - # so fallback to naive if per channel or per token - return torch_channelwise_w8a8_scaled_mm - - -# TODO(luka): follow similar pattern for marlin and block-fp8-linear -# https://github.com/vllm-project/vllm/issues/14397 -class Fp8LinearOp: - """ - This class executes a FP8 linear layer using cutlass if supported and - torch.scaled_mm 
otherwise. - It needs to be a class instead of a method so that config can be read - in the __init__ method, as reading config is not allowed inside forward. - """ - - def __init__( - self, - act_quant_static: bool, - act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR, - pad_output: bool | None = None, - ): - if current_platform.is_rocm(): - self.preferred_backend = "rocm" - elif current_platform.is_cuda() and cutlass_fp8_supported(): - if has_flashinfer() and current_platform.has_device_capability(100): - self.preferred_backend = "flashinfer" - else: - self.preferred_backend = "cutlass" - else: - self.preferred_backend = "torch" - - # Note: we pad the input because torch._scaled_mm is more performant - # for matrices with batch dimension > 16. - # This could change in the future. - # We also don't pad when using torch.compile, - # as it breaks with dynamic shapes. - if pad_output is None: - config = get_current_vllm_config().compilation_config - pad_output = ( - config.mode < CompilationMode.VLLM_COMPILE - and self.preferred_backend == "torch" - ) - - self.output_padding = 17 if pad_output else None - self.act_quant_static = act_quant_static - self.act_quant_group_shape = act_quant_group_shape - self.quant_fp8 = QuantFP8( - static=act_quant_static, - group_shape=act_quant_group_shape, - num_token_padding=self.output_padding, - ) - - def apply( - self, - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - out_dtype: torch.dtype | None = None, - input_scale: torch.Tensor | None = None, - input_scale_ub: torch.Tensor | None = None, - bias: torch.Tensor | None = None, - ) -> torch.Tensor: - # ops.scaled_fp8_quant supports both dynamic and static quant. - # If dynamic, layer.input_scale is None and x_scale computed from x. - # If static, layer.input_scale is scalar and x_scale is input_scale. 
- - # View input as 2D matrix for fp8 methods - input_2d = input.view(-1, input.shape[-1]) - output_shape = [*input.shape[:-1], weight.shape[1]] - - if out_dtype is None: - out_dtype = input.dtype - - # If input not quantized - # TODO(luka) remove this path if not used anymore - if input.dtype != current_platform.fp8_dtype(): - qinput, x_scale = self.quant_fp8( - input_2d, - input_scale, - input_scale_ub, - ) - else: - qinput, x_scale = input_2d, input_scale - - # Must have dim() conditions - # In per-token quant scenario, when the number of token is 1, - # the scale will only have 1 elements. - # Without checking the dim(), - # we cannot distingushes between per-tensor and per-token quant. - # Example: - # When the number of token is 1, per-token scale is [[1]] - # When per-tensor scale is [1] or (). - per_tensor_weights = weight_scale.numel() == 1 - per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2 - - # TODO(luka) do this dispatch during init (after ScaledMM refactor) - w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm( - self.preferred_backend, per_tensor_weights, per_tensor_activations - ) - - return w8a8_scaled_mm_func( - qinput=qinput, - weight=weight, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias, - output_shape=output_shape, - ) - - def normalize_e4m3fn_to_e4m3fnuz( weight: torch.Tensor, weight_scale: torch.Tensor,