[BugFix][AMD][Quantization] Fix torch.compile issue where wvSplitKQ was not being called when it should be for quantized FP8 models (#22281)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
parent 0313cf854d
commit cc7ae5e7ca
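The fix wraps the ROCm per-tensor FP8 scaled-mm path in a registered custom op, so torch.compile records an opaque op call instead of tracing through the Python function and losing the runtime dispatch to the wvSplitKQ skinny-GEMM kernel. Below is a minimal, self-contained sketch of that pattern using plain torch.library.custom_op (PyTorch >= 2.4) rather than vLLM's direct_register_custom_op helper; the "demo" namespace, shapes, and stand-in kernels are illustrative assumptions, not vLLM code.

# Minimal sketch of the custom-op + fake-impl pattern the commit applies.
# Assumptions: the "demo" namespace and the stand-in kernels below are
# illustrative only; vLLM registers its op via direct_register_custom_op.
import torch


@torch.library.custom_op("demo::per_tensor_scaled_mm", mutates_args=())
def per_tensor_scaled_mm(qinput: torch.Tensor,
                         weight: torch.Tensor) -> torch.Tensor:
    # Real implementation: runtime branching (e.g. picking a skinny-GEMM
    # kernel such as wvSplitKQ when the batch dimension is 1) stays inside
    # the op, so the compiled graph cannot specialize it away.
    if qinput.shape[0] == 1:
        return qinput @ weight  # stand-in for the specialized kernel
    return torch.matmul(qinput, weight)


@per_tensor_scaled_mm.register_fake
def _(qinput: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Fake impl: only reports output shape/dtype so the graph can be traced.
    return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]))


@torch.compile
def forward(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # The compiled graph sees an opaque custom op, analogous to how the diff
    # below calls torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl.
    return torch.ops.demo.per_tensor_scaled_mm(x, w)


if __name__ == "__main__":
    x = torch.randn(1, 16)
    w = torch.randn(16, 32)
    print(forward(x, w).shape)  # torch.Size([1, 32])

The fake implementation only describes output metadata, which is all the compiler needs; the eager implementation, with its shape- and platform-dependent branch, still runs at execution time. The diff below applies the same idea with rocm_per_tensor_w8a8_scaled_mm_impl, rocm_per_tensor_w8a8_scaled_mm_fake, and direct_register_custom_op.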
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape)
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op

 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
@@ -156,13 +157,10 @@ def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
     return output.view(*output_shape)


-def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
-                                   weight: torch.Tensor,
-                                   out_dtype: torch.dtype,
-                                   scale_a: torch.Tensor,
-                                   scale_b: torch.Tensor, bias: torch.Tensor,
-                                   input_2d: torch.Tensor,
-                                   output_shape: list) -> torch.Tensor:
+def rocm_per_tensor_w8a8_scaled_mm_impl(
+        qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
+        scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
+        input_2d: torch.Tensor) -> torch.Tensor:
     from vllm.platforms.rocm import on_mi3xx
     if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
     ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
@@ -175,10 +173,38 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                   scale_a=scale_a,
                                   scale_b=scale_b,
                                   bias=bias)
+    return output
+
+
+def rocm_per_tensor_w8a8_scaled_mm_fake(
+        qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
+        scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
+        input_2d: torch.Tensor) -> torch.Tensor:
+    return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]),
+                            dtype=out_dtype)
+
+
+def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
+                                   weight: torch.Tensor,
+                                   out_dtype: torch.dtype,
+                                   scale_a: torch.Tensor,
+                                   scale_b: torch.Tensor, bias: torch.Tensor,
+                                   input_2d: torch.Tensor,
+                                   output_shape: list) -> torch.Tensor:
+    output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl(
+        qinput, weight, out_dtype, scale_a, scale_b, bias, input_2d)
+    return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
+
+
+direct_register_custom_op(
+    op_name="rocm_per_tensor_w8a8_scaled_mm_impl",
+    op_func=rocm_per_tensor_w8a8_scaled_mm_impl,
+    mutates_args=[],
+    fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake,
+    dispatch_key=current_platform.dispatch_key,
+)


 def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                     weight: torch.Tensor,
                                     out_dtype: torch.dtype,