From b382a7f28f739f3b120e5495fd029089d0399428 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 26 Feb 2025 13:48:55 -0800
Subject: [PATCH] [BugFix] Make FP8 Linear compatible with torch.compile (#13918)

Signed-off-by: Woosuk Kwon
---
 .../model_executor/layers/quantization/fp8.py |  5 +----
 .../layers/quantization/utils/fp8_utils.py    | 20 +++++++++++++++++++
 2 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 76a7d4df8a36..a705f63be4ac 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -369,12 +369,9 @@ class Fp8LinearMethod(LinearMethodBase):
                 size_k=layer.input_size_per_partition,
                 bias=bias)
 
-        # Note: lazy import to avoid triton import error.
-        from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-            apply_w8a8_block_fp8_linear)
         if self.block_quant:
             assert self.quant_config.weight_block_size is not None
-            return apply_w8a8_block_fp8_linear(
+            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
                 input=x,
                 weight=layer.weight,
                 block_size=self.quant_config.weight_block_size,
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 61706f485f46..7d91d2cf1c6e 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear)
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 logger = init_logger(__name__)
 
@@ -81,6 +82,25 @@ def apply_w8a8_block_fp8_linear(
     return output.to(dtype=input.dtype).view(*output_shape)
 
 
+def apply_w8a8_block_fp8_linear_fake(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+    return torch.empty(output_shape, dtype=input.dtype, device=input.device)
+
+
+direct_register_custom_op(
+    op_name="apply_w8a8_block_fp8_linear",
+    op_func=apply_w8a8_block_fp8_linear,
+    mutates_args=[],
+    fake_impl=apply_w8a8_block_fp8_linear_fake,
+)
+
+
 # Unify the interface between `apply_w8a8_block_fp8_linear` and
 # `apply_fp8_linear`
 # NOTE(lucas): this is quite messy, we should think through this more formally
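
The sketch below is not part of the patch; it illustrates why registering a custom op with a fake (shape-only) implementation is what makes the FP8 path compatible with torch.compile: Dynamo traces the op with fake/meta tensors, so it needs an implementation that reports output shape and dtype without touching real data. It uses the public torch.library API (PyTorch >= 2.4) as a rough standalone analogue of what vllm.utils.direct_register_custom_op does; the "demo::block_fp8_linear" op and the plain-matmul body are hypothetical and stand in for the real blockwise FP8 kernel.

```python
import torch


# Hypothetical custom op standing in for apply_w8a8_block_fp8_linear.
@torch.library.custom_op("demo::block_fp8_linear", mutates_args=())
def block_fp8_linear(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # The real op runs a blockwise FP8 quantized matmul; a plain matmul
    # is used here only to keep the example self-contained.
    return input @ weight.t()


@block_fp8_linear.register_fake
def _(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation used during tracing, mirroring
    # apply_w8a8_block_fp8_linear_fake in the patch.
    return torch.empty(*input.shape[:-1], weight.shape[0],
                       dtype=input.dtype, device=input.device)


@torch.compile(fullgraph=True)
def fwd(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Called through torch.ops, just as fp8.py now calls
    # torch.ops.vllm.apply_w8a8_block_fp8_linear.
    return torch.ops.demo.block_fp8_linear(x, w)


x, w = torch.randn(4, 16), torch.randn(32, 16)
print(fwd(x, w).shape)  # torch.Size([4, 32])
```

Without the fake implementation, compilation would fail (or fall back) when Dynamo tries to trace through the opaque kernel, which is why the patch can drop the lazy import and call the op through torch.ops.vllm instead.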