[BugFix] Make FP8 Linear compatible with torch.compile (#13918)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon 2025-02-26 13:48:55 -08:00 committed by GitHub
parent 4cb6fa0a9c
commit b382a7f28f
2 changed files with 21 additions and 4 deletions

vllm/model_executor/layers/quantization/fp8.py

@@ -369,12 +369,9 @@ class Fp8LinearMethod(LinearMethodBase):
                 size_k=layer.input_size_per_partition,
                 bias=bias)
 
-        # Note: lazy import to avoid triton import error.
-        from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-            apply_w8a8_block_fp8_linear)
         if self.block_quant:
             assert self.quant_config.weight_block_size is not None
-            return apply_w8a8_block_fp8_linear(
+            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
                 input=x,
                 weight=layer.weight,
                 block_size=self.quant_config.weight_block_size,
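
For context on this call-site change: dispatching through torch.ops.vllm.apply_w8a8_block_fp8_linear lets torch.compile record the block-FP8 GEMM as a single opaque custom op, with output shapes supplied by the fake implementation registered in fp8_utils.py (second file below), instead of Dynamo trying to trace into the Triton-backed Python helper; that is also why the lazy import is no longer needed. Below is a minimal, self-contained sketch of the same pattern using the public torch.library.custom_op API (PyTorch 2.4+); the demo_ns::block_linear op and all other names are illustrative, not vLLM code.

import torch


# Toy stand-in for the Triton-backed kernel wrapper (illustrative names).
@torch.library.custom_op("demo_ns::block_linear", mutates_args=())
def block_linear(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return x @ w.t()


# Fake (meta) implementation: only reports output shape/dtype and never runs
# the kernel, so Dynamo can trace the op call with fake tensors.
@block_linear.register_fake
def _(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return torch.empty((*x.shape[:-1], w.shape[0]), dtype=x.dtype,
                       device=x.device)


# Call sites reach the op via the global torch.ops namespace, so they do not
# import the implementing module's symbols directly. fullgraph=True would
# raise if the op call caused a graph break; backend="eager" keeps the sketch
# free of an Inductor/Triton dependency.
fn = torch.compile(lambda x, w: torch.ops.demo_ns.block_linear(x, w),
                   backend="eager", fullgraph=True)
print(fn(torch.randn(2, 4), torch.randn(8, 4)).shape)  # torch.Size([2, 8])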

vllm/model_executor/layers/quantization/utils/fp8_utils.py

@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear)
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 logger = init_logger(__name__)
@@ -81,6 +82,25 @@ def apply_w8a8_block_fp8_linear(
     return output.to(dtype=input.dtype).view(*output_shape)
 
 
+def apply_w8a8_block_fp8_linear_fake(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+    return torch.empty(output_shape, dtype=input.dtype, device=input.device)
+
+
+direct_register_custom_op(
+    op_name="apply_w8a8_block_fp8_linear",
+    op_func=apply_w8a8_block_fp8_linear,
+    mutates_args=[],
+    fake_impl=apply_w8a8_block_fp8_linear_fake,
+)
+
+
 # Unify the interface between `apply_w8a8_block_fp8_linear` and
 # `apply_fp8_linear`
 # NOTE(lucas): this is quite messy, we should think through this more formally
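
direct_register_custom_op is vLLM's helper around torch.library. As a rough sketch of what a registration like the one above amounts to, here is an approximation written against the low-level torch.library.Library API; the demo_vllm namespace, schema, and op are hypothetical, and the real helper handles more (schema inference, per-platform dispatch keys, optional tags), whereas this sketch binds CompositeExplicitAutograd so it runs anywhere.

import torch
from torch.library import Library

# Hypothetical namespace; vLLM registers its ops under a shared library object.
_lib = Library("demo_vllm", "FRAGMENT")


def block_linear(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Eager "real" implementation; stands in for the Triton block-FP8 GEMM.
    return x @ w.t()


def block_linear_fake(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation used during torch.compile tracing,
    # mirroring apply_w8a8_block_fp8_linear_fake above.
    output_shape = [*x.shape[:-1], w.shape[0]]
    return torch.empty(output_shape, dtype=x.dtype, device=x.device)


_lib.define("block_linear(Tensor x, Tensor w) -> Tensor")
_lib.impl("block_linear", block_linear, "CompositeExplicitAutograd")
torch.library.register_fake("demo_vllm::block_linear", block_linear_fake)

# The op is now reachable from any call site as torch.ops.demo_vllm.block_linear.
out = torch.ops.demo_vllm.block_linear(torch.randn(2, 4), torch.randn(8, 4))
print(out.shape)  # torch.Size([2, 8])

To check that a fake implementation stays consistent with the real one, torch.library.opcheck can be run against sample inputs.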