Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 05:45:00 +08:00)
[BugFix] Make FP8 Linear compatible with torch.compile (#13918)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 4cb6fa0a9c
commit b382a7f28f
@@ -369,12 +369,9 @@ class Fp8LinearMethod(LinearMethodBase):
                 size_k=layer.input_size_per_partition,
                 bias=bias)
 
-        # Note: lazy import to avoid triton import error.
-        from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-            apply_w8a8_block_fp8_linear)
         if self.block_quant:
             assert self.quant_config.weight_block_size is not None
-            return apply_w8a8_block_fp8_linear(
+            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
                 input=x,
                 weight=layer.weight,
                 block_size=self.quant_config.weight_block_size,

@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear)
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 logger = init_logger(__name__)
 

@@ -81,6 +82,25 @@ def apply_w8a8_block_fp8_linear(
     return output.to(dtype=input.dtype).view(*output_shape)
 
 
+def apply_w8a8_block_fp8_linear_fake(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+    return torch.empty(output_shape, dtype=input.dtype, device=input.device)
+
+
+direct_register_custom_op(
+    op_name="apply_w8a8_block_fp8_linear",
+    op_func=apply_w8a8_block_fp8_linear,
+    mutates_args=[],
+    fake_impl=apply_w8a8_block_fp8_linear_fake,
+)
+
+
 # Unify the interface between `apply_w8a8_block_fp8_linear` and
 # `apply_fp8_linear`
 # NOTE(lucas): this is quite messy, we should think through this more formally
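For context: the first hunk changes the call site in Fp8LinearMethod.apply to go through torch.ops.vllm, and the later hunks add a shape-only fake implementation and register apply_w8a8_block_fp8_linear as a custom op via direct_register_custom_op. Below is a minimal, self-contained sketch of that pattern using PyTorch's public torch.library.custom_op API (PyTorch 2.4+) rather than vLLM's helper; the op name mylib::block_fp8_linear, the dequantize-then-matmul body, and the forward wrapper are illustrative assumptions, not vLLM's actual kernel.

import torch
from torch.library import custom_op


# Hypothetical namespace and op name, for illustration only.
@custom_op("mylib::block_fp8_linear", mutates_args=())
def block_fp8_linear(x: torch.Tensor, weight: torch.Tensor,
                     weight_scale: torch.Tensor) -> torch.Tensor:
    # Eager implementation. The real op dispatches to a Triton/CUTLASS
    # block-FP8 kernel; a dequantize-then-matmul stand-in keeps this runnable.
    return x @ (weight.to(torch.float32) * weight_scale).to(x.dtype).t()


@block_fp8_linear.register_fake
def _(x: torch.Tensor, weight: torch.Tensor,
      weight_scale: torch.Tensor) -> torch.Tensor:
    # Fake (meta) implementation: returns only the right shape/dtype/device,
    # mirroring apply_w8a8_block_fp8_linear_fake in the diff above.
    # torch.compile uses it to trace the graph without running the kernel.
    output_shape = [*x.shape[:-1], weight.shape[0]]
    return torch.empty(output_shape, dtype=x.dtype, device=x.device)


@torch.compile(fullgraph=True)
def forward(x: torch.Tensor, w: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    # Calling through torch.ops keeps the kernel opaque to the compiled graph,
    # analogous to torch.ops.vllm.apply_w8a8_block_fp8_linear in the commit.
    return torch.ops.mylib.block_fp8_linear(x, w, s)


if __name__ == "__main__":
    x = torch.randn(4, 8)
    w = torch.randn(16, 8)
    s = torch.full((16, 1), 0.5)
    print(forward(x, w, s).shape)  # torch.Size([4, 16])

vLLM's direct_register_custom_op achieves the same effect through the lower-level torch.library machinery; the fake implementation is the piece this commit needed, since torch.compile has to know output shapes and dtypes without executing the Triton kernel.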