[AMD][BugFix] Fix omission of wvSplitK kernel for small batch sizes (1-4) due to torch.compile (#21350)

Signed-off-by: Randall Smith <Randall.Smith@amd.com>
rasmith 2025-07-28 14:38:20 -05:00 committed by GitHub
parent 01c753ed98
commit b361f14e39
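Why the fix works: torch.compile traces rocm_unquantized_gemm as ordinary Python, so the runtime branch that selects the wvSplitK skinny-GEMM kernel for batch sizes 1-4 is specialized away in the compiled graph. Registering the implementation through direct_register_custom_op makes the call opaque to Dynamo: the graph keeps a single torch.ops.vllm.rocm_unquantized_gemm_impl node, the fake impl supplies output shapes during tracing, and the real impl (with its kernel dispatch) runs at execution time.

Below is a minimal sketch of the same pattern using plain PyTorch APIs (torch.library.custom_op, available in PyTorch >= 2.4); the demo:: namespace and function names are illustrative, not part of this commit. vLLM's direct_register_custom_op wraps equivalent torch.library machinery.

from typing import Optional

import torch


@torch.library.custom_op("demo::unquantized_gemm", mutates_args=())
def demo_unquantized_gemm(
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Real impl: runtime kernel selection lives here, invisible to Dynamo.
    return torch.nn.functional.linear(x, weight, bias)


@demo_unquantized_gemm.register_fake
def _(x: torch.Tensor,
      weight: torch.Tensor,
      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Fake impl: shape/dtype propagation only, used while tracing.
    return x.new_empty((*x.shape[:-1], weight.shape[0]))

Compiling a function that calls demo_unquantized_gemm records one demo::unquantized_gemm node rather than tracing into its body, so any data-dependent dispatch inside the real impl still happens at runtime.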


@@ -8,6 +8,7 @@ import torch
 from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 
 def get_token_bin_counts_and_mask(
@@ -70,10 +71,10 @@ def default_unquantized_gemm(layer: torch.nn.Module,
     return torch.nn.functional.linear(x, weight, bias)
 
 
-def rocm_unquantized_gemm(layer: torch.nn.Module,
-                          x: torch.Tensor,
-                          weight: torch.Tensor,
-                          bias: Optional[torch.Tensor] = None):
+def rocm_unquantized_gemm_impl(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     from vllm.platforms.rocm import on_gfx9
     k = weight.shape[1]
     use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
@@ -97,6 +98,29 @@ def rocm_unquantized_gemm(layer: torch.nn.Module,
     return torch.nn.functional.linear(x, weight, bias)
 
 
+def rocm_unquantized_gemm_impl_fake(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return x.new_empty((*x.shape[:-1], weight.shape[0]))
+
+
+def rocm_unquantized_gemm(layer: torch.nn.Module,
+                          x: torch.Tensor,
+                          weight: torch.Tensor,
+                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
+
+
+direct_register_custom_op(
+    op_name="rocm_unquantized_gemm_impl",
+    op_func=rocm_unquantized_gemm_impl,
+    mutates_args=[],
+    fake_impl=rocm_unquantized_gemm_impl_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
+
+
 def cpu_unquantized_gemm(layer: torch.nn.Module,
                          x: torch.Tensor,
                          weight: torch.Tensor,
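A hypothetical smoke test (not part of this commit) that the registered op survives compilation; it assumes a ROCm gfx9 GPU, VLLM_ROCM_USE_SKINNY_GEMM enabled, and that importing the patched module (assumed here to be vllm.model_executor.layers.utils) performs the registration.

import torch

import vllm.model_executor.layers.utils  # noqa: F401  (assumed module path; import runs direct_register_custom_op)


def gemm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return torch.ops.vllm.rocm_unquantized_gemm_impl(x, w, None)


# Batch size 2 falls in the 1-4 range that previously lost the wvSplitK path.
x = torch.randn(2, 4096, device="cuda", dtype=torch.float16)
w = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
out = torch.compile(gemm)(x, w)
assert out.shape == (2, 4096)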