[AMD][BugFix] Fix omission of wvSplitK kernel for small batch sizes (1-4) due to torch.compile (#21350)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
parent 01c753ed98
commit b361f14e39
@@ -8,6 +8,7 @@ import torch
 from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 
 def get_token_bin_counts_and_mask(
@@ -70,10 +71,10 @@ def default_unquantized_gemm(layer: torch.nn.Module,
     return torch.nn.functional.linear(x, weight, bias)
 
 
-def rocm_unquantized_gemm(layer: torch.nn.Module,
-                          x: torch.Tensor,
-                          weight: torch.Tensor,
-                          bias: Optional[torch.Tensor] = None):
+def rocm_unquantized_gemm_impl(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     from vllm.platforms.rocm import on_gfx9
     k = weight.shape[1]
     use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
@@ -97,6 +98,29 @@ def rocm_unquantized_gemm(layer: torch.nn.Module,
     return torch.nn.functional.linear(x, weight, bias)
 
 
+def rocm_unquantized_gemm_impl_fake(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return x.new_empty((*x.shape[:-1], weight.shape[0]))
+
+
+def rocm_unquantized_gemm(layer: torch.nn.Module,
+                          x: torch.Tensor,
+                          weight: torch.Tensor,
+                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
+
+
+direct_register_custom_op(
+    op_name="rocm_unquantized_gemm_impl",
+    op_func=rocm_unquantized_gemm_impl,
+    mutates_args=[],
+    fake_impl=rocm_unquantized_gemm_impl_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
+
+
 def cpu_unquantized_gemm(layer: torch.nn.Module,
                          x: torch.Tensor,
                          weight: torch.Tensor,
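For context, the fix hides the ROCm skinny-GEMM dispatch inside a registered custom op: torch.compile then records a single opaque call instead of tracing through the Python-level branch, so the runtime check that routes small batches (1-4) to the wvSplitK kernel can no longer be specialized away, while the fake impl supplies the output shape the compiler needs. Below is a minimal, self-contained sketch of the same pattern using PyTorch's public torch.library API (2.4+) rather than vLLM's direct_register_custom_op wrapper; the demo:: namespace, the op name, and the m <= 4 threshold are illustrative assumptions, not vLLM code.

import torch

# Hypothetical demo op; the real commit registers rocm_unquantized_gemm_impl
# through vLLM's direct_register_custom_op helper instead.
@torch.library.custom_op("demo::skinny_gemm", mutates_args=())
def skinny_gemm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # The batch-size branch lives inside the op, so torch.compile treats
    # the call as a black box and cannot compile the branch away.
    m = x.view(-1, x.shape[-1]).shape[0]
    if m <= 4:
        # Stand-in for the specialized small-batch kernel
        # (ops.wvSplitK in the real ROCm path).
        return torch.nn.functional.linear(x, weight)
    return x @ weight.t()

@skinny_gemm.register_fake
def _(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Fake (meta) impl used during tracing: shape/dtype only, mirroring
    # the diff's x.new_empty((*x.shape[:-1], weight.shape[0])).
    return x.new_empty((*x.shape[:-1], weight.shape[0]))

compiled = torch.compile(lambda x, w: skinny_gemm(x, w))
out = compiled(torch.randn(2, 64), torch.randn(128, 64))  # m = 2: skinny path
assert out.shape == (2, 128)

The diff achieves the same effect with an explicit fake_impl and dispatch key, and by keeping the public rocm_unquantized_gemm signature (including the now-unforwarded layer argument) it leaves all call sites unchanged.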