mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-19 07:25:01 +08:00
[AMD][BugFix] Fix omission of wvSplitK kernel for small batch sizes (1-4) due to torch.compile (#21350)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
This commit is contained in:
parent
01c753ed98
commit
b361f14e39
@ -8,6 +8,7 @@ import torch
|
|||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
|
||||||
def get_token_bin_counts_and_mask(
|
def get_token_bin_counts_and_mask(
|
||||||
@ -70,10 +71,10 @@ def default_unquantized_gemm(layer: torch.nn.Module,
|
|||||||
return torch.nn.functional.linear(x, weight, bias)
|
return torch.nn.functional.linear(x, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
def rocm_unquantized_gemm(layer: torch.nn.Module,
|
def rocm_unquantized_gemm_impl(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
bias: Optional[torch.Tensor] = None):
|
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
from vllm.platforms.rocm import on_gfx9
|
from vllm.platforms.rocm import on_gfx9
|
||||||
k = weight.shape[1]
|
k = weight.shape[1]
|
||||||
use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
|
use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
|
||||||
@ -97,6 +98,29 @@ def rocm_unquantized_gemm(layer: torch.nn.Module,
|
|||||||
return torch.nn.functional.linear(x, weight, bias)
|
return torch.nn.functional.linear(x, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
|
def rocm_unquantized_gemm_impl_fake(
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Shape-only (fake/meta) counterpart of ``rocm_unquantized_gemm_impl``.

    Used by the custom-op registration so tracing (e.g. torch.compile) can
    infer the output without running the real kernel. Returns an
    uninitialized tensor with the linear-layer output shape: the input's
    leading dims with the last dim replaced by ``weight.shape[0]``.
    ``bias`` does not affect the output shape and is ignored here.
    """
    out_features = weight.shape[0]
    out_shape = (*x.shape[:-1], out_features)
    # new_empty preserves x's dtype and device; contents are uninitialized.
    return x.new_empty(out_shape)
||||||
|
|
||||||
|
|
||||||
|
def rocm_unquantized_gemm(layer: torch.nn.Module,
                          x: torch.Tensor,
                          weight: torch.Tensor,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Public entry point for the ROCm unquantized GEMM.

    Dispatches through the ``vllm.rocm_unquantized_gemm_impl`` custom op
    (registered below) rather than calling the implementation directly, so
    torch.compile treats the GEMM as an opaque op. ``layer`` is accepted for
    signature compatibility with the other ``*_unquantized_gemm`` backends
    and is not forwarded to the op.
    """
    gemm_op = torch.ops.vllm.rocm_unquantized_gemm_impl
    return gemm_op(x, weight, bias)
||||||
|
|
||||||
|
|
||||||
|
# Register the ROCm GEMM implementation as a PyTorch custom op, with a fake
# (shape-only) implementation for tracing. Per the commit message, routing
# the call through a custom op keeps torch.compile from tracing into the
# implementation — presumably this is what previously caused the wvSplitK
# skinny-GEMM path for small batch sizes (1-4) to be skipped; confirm
# against the kernel-selection logic in rocm_unquantized_gemm_impl.
direct_register_custom_op(
    op_name="rocm_unquantized_gemm_impl",
    op_func=rocm_unquantized_gemm_impl,
    # The op allocates its output and does not mutate any argument.
    mutates_args=[],
    fake_impl=rocm_unquantized_gemm_impl_fake,
    dispatch_key=current_platform.dispatch_key,
)
||||||
|
|
||||||
|
|
||||||
def cpu_unquantized_gemm(layer: torch.nn.Module,
|
def cpu_unquantized_gemm(layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user