From b361f14e394861e31f61ab980d15dc07e5c74290 Mon Sep 17 00:00:00 2001 From: rasmith Date: Mon, 28 Jul 2025 14:38:20 -0500 Subject: [PATCH] [AMD][BugFix] Fix omission of wvSplitK kernel for small batch sizes (1-4) due to torch.compile (#21350) Signed-off-by: Randall Smith --- vllm/model_executor/layers/utils.py | 32 +++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index ad4ba9c0b827a..cd32f12f3c269 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -8,6 +8,7 @@ import torch from vllm import _custom_ops as ops from vllm import envs from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op def get_token_bin_counts_and_mask( @@ -70,10 +71,10 @@ def default_unquantized_gemm(layer: torch.nn.Module, return torch.nn.functional.linear(x, weight, bias) -def rocm_unquantized_gemm(layer: torch.nn.Module, - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None): +def rocm_unquantized_gemm_impl( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: from vllm.platforms.rocm import on_gfx9 k = weight.shape[1] use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \ @@ -97,6 +98,29 @@ def rocm_unquantized_gemm(layer: torch.nn.Module, return torch.nn.functional.linear(x, weight, bias) +def rocm_unquantized_gemm_impl_fake( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return x.new_empty((*x.shape[:-1], weight.shape[0])) + + +def rocm_unquantized_gemm(layer: torch.nn.Module, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias) + + +direct_register_custom_op( + op_name="rocm_unquantized_gemm_impl", + op_func=rocm_unquantized_gemm_impl, + mutates_args=[], + fake_impl=rocm_unquantized_gemm_impl_fake, + dispatch_key=current_platform.dispatch_key, +) + + def cpu_unquantized_gemm(layer: torch.nn.Module, x: torch.Tensor, weight: torch.Tensor,