diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 5b56437487477..829c47003ad0e 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -189,8 +189,34 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 def scaled_fp8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
+    batch_dim_padding: Optional[int] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+    """
+    Quantize input tensor to FP8 and return quantized tensor and scale.
+
+    This function supports both static and dynamic quantization: If you
+    provide the scale, it will use static scaling and if you omit it,
+    the scale will be determined dynamically. The function also allows
+    optional padding of the output tensor for downstream kernels that
+    will benefit from padding.
+
+    Args:
+        input: The input tensor to be quantized to FP8
+        scale: Optional scaling factor for the FP8 quantization
+        batch_dim_padding: If specified, pad the first dimension
+            of the output to at least this value.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
+            scaling factor.
+    """
+    if batch_dim_padding:
+        shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:])
+        output = torch.empty(shape,
+                             device=input.device,
+                             dtype=torch.float8_e4m3fn)
+    else:
+        output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
     if scale is None:
         scale = torch.zeros(1, device=input.device, dtype=torch.float32)
         vllm_ops.dynamic_scaled_fp8_quant(output, input, scale)
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index b57e1dde81a5f..ff996741c1d00 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -231,9 +231,14 @@ class Fp8LinearMethod(LinearMethodBase):
         # ops.scaled_fp8_quant supports both dynamic and static quant.
         # If dynamic, layer.act_scale is None and x_scale computed from x.
         # If static, layer.act_scale is scalar and x_scale set to act_scale.
-        qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale)
+        qinput, x_scale = ops.scaled_fp8_quant(x,
+                                               layer.act_scale,
+                                               batch_dim_padding=17)
 
-        # Fused GEMM_DQ
+        # Fused GEMM_DQ -- note we padded the input above because
+        # torch._scaled_mm is more performant for matrices with
+        # batch dimension > 16. Note that this could change
+        # in the future.
         output, _ = torch._scaled_mm(
             qinput,
             layer.weight,
@@ -243,7 +248,7 @@ class Fp8LinearMethod(LinearMethodBase):
             bias=bias,
         )
 
-        return output
+        return torch.narrow(output, 0, 0, x.shape[0])
 
 
 def all_close_1d(x: torch.Tensor) -> bool:
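
To make the pad-then-narrow contract easier to review, here is a minimal, self-contained sketch of the same shape handling in plain PyTorch. It is illustrative only: `pad_batch_dim` is a hypothetical helper (not part of this PR), the plain matmul stands in for the fused `torch._scaled_mm` GEMM (which additionally requires an FP8-capable GPU), and zero-fill replaces the kernel's `torch.empty`, whose uninitialized padding rows are safe only because they are sliced away afterwards.

```python
import torch


def pad_batch_dim(x: torch.Tensor, batch_dim_padding: int) -> torch.Tensor:
    # Pad dim 0 up to at least `batch_dim_padding` rows. Hypothetical helper;
    # the real kernel allocates with torch.empty because the padded rows are
    # never read after the final narrow.
    if x.shape[0] >= batch_dim_padding:
        return x
    shape = (batch_dim_padding, *x.shape[1:])
    out = torch.zeros(shape, device=x.device, dtype=x.dtype)
    out[:x.shape[0]] = x
    return out


x = torch.randn(3, 8)                  # batch dim 3: below the >16 threshold
padded = pad_batch_dim(x, 17)          # (17, 8); rows 3..16 are padding
w = torch.randn(8, 4)
y = padded @ w                         # stand-in for the fused GEMM_DQ
y = torch.narrow(y, 0, 0, x.shape[0])  # trim back to the caller's batch size
assert y.shape == (3, 4)
```

The final `torch.narrow` is what keeps the padding invisible to callers: whatever the extra rows produce in the GEMM output is discarded before the result is returned, so only the GEMM itself ever sees the enlarged batch dimension.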