diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index ad4ba9c0b827a..cd32f12f3c269 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -8,6 +8,7 @@ import torch
 from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 
 def get_token_bin_counts_and_mask(
@@ -70,10 +71,10 @@ def default_unquantized_gemm(layer: torch.nn.Module,
     return torch.nn.functional.linear(x, weight, bias)
 
 
-def rocm_unquantized_gemm(layer: torch.nn.Module,
-                          x: torch.Tensor,
-                          weight: torch.Tensor,
-                          bias: Optional[torch.Tensor] = None):
+def rocm_unquantized_gemm_impl(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     from vllm.platforms.rocm import on_gfx9
     k = weight.shape[1]
     use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
@@ -97,6 +98,29 @@ def rocm_unquantized_gemm(layer: torch.nn.Module,
     return torch.nn.functional.linear(x, weight, bias)
 
 
+def rocm_unquantized_gemm_impl_fake(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return x.new_empty((*x.shape[:-1], weight.shape[0]))
+
+
+def rocm_unquantized_gemm(layer: torch.nn.Module,
+                          x: torch.Tensor,
+                          weight: torch.Tensor,
+                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
+
+
+direct_register_custom_op(
+    op_name="rocm_unquantized_gemm_impl",
+    op_func=rocm_unquantized_gemm_impl,
+    mutates_args=[],
+    fake_impl=rocm_unquantized_gemm_impl_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
+
+
 def cpu_unquantized_gemm(layer: torch.nn.Module,
                          x: torch.Tensor,
                          weight: torch.Tensor,
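
The diff splits the ROCm GEMM into a real implementation plus a fake (meta) variant and registers the pair as a `torch.ops.vllm` custom op, so `torch.compile`/FakeTensor tracing can propagate output shapes without executing ROCm kernels. The thin `rocm_unquantized_gemm` wrapper keeps the `layer` argument out of the op signature, since op schemas only accept tensors and simple scalar types, not an `nn.Module`. Below is a minimal, self-contained sketch of the same pattern using the stock `torch.library` API (PyTorch >= 2.4); the `mylib::row_linear` name is hypothetical, and vLLM itself goes through its `direct_register_custom_op` helper as shown in the diff rather than this decorator form.

```python
from typing import Optional

import torch


@torch.library.custom_op("mylib::row_linear", mutates_args=())
def row_linear(x: torch.Tensor, weight: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Real implementation: runs only when the op is actually executed.
    return torch.nn.functional.linear(x, weight, bias)


@row_linear.register_fake
def _(x: torch.Tensor, weight: torch.Tensor,
      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Fake (meta) implementation: only describes the output shape/dtype,
    # mirroring rocm_unquantized_gemm_impl_fake in the diff above.
    return x.new_empty((*x.shape[:-1], weight.shape[0]))


x = torch.randn(4, 16)
w = torch.randn(32, 16)
out = torch.ops.mylib.row_linear(x, w)      # eager call through the op registry
compiled = torch.compile(lambda a, b: torch.ops.mylib.row_linear(a, b))
assert compiled(x, w).shape == (4, 32)      # traced via the fake impl
```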