Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2026-01-01 21:04:11 +08:00
[Bugfix][ROCm][Dynamo][DS 3.1][FP8] fix unsupported hasattr call when Dynamo tracing for ROCm device (#31149)
Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
This commit is contained in:
parent 3bb9561928
commit dabff12ed3
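What the change does, as reflected in the diff below: the AIter Triton blockscale GEMM is exposed as a registered vLLM custom op with two implementations, an eager one that imports aiter lazily and runs the real kernel, and a fake one that only allocates an (M, N) output for shape propagation. Dynamo then records the op as a single opaque node instead of tracing into aiter, which avoids the unsupported hasattr call the title refers to. The following is a minimal, self-contained sketch of the same pattern using PyTorch's public torch.library API; the "demo::lazy_gemm" name and the CPU matmul body are illustrative stand-ins, not the vLLM/aiter code.

import torch


# Hypothetical namespace/op name, for illustration only.
@torch.library.custom_op("demo::lazy_gemm", mutates_args=())
def lazy_gemm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # In vLLM the heavy import (aiter) happens here, inside the eager impl,
    # so it is only paid when the op actually runs. A plain matmul stands in.
    return A @ B.t()


@lazy_gemm.register_fake
def _(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation used while torch.compile traces:
    # no kernel launch, no heavy import, just the output metadata.
    return A.new_empty(A.shape[0], B.shape[0])


@torch.compile
def f(A, B):
    # Dynamo records a single call to demo::lazy_gemm; it never looks inside.
    return torch.ops.demo.lazy_gemm(A, B)


print(f(torch.randn(8, 16), torch.randn(4, 16)).shape)  # torch.Size([8, 4])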
@@ -380,6 +380,31 @@ def _rocm_aiter_gemm_a8w8_fake(
     return Y
 
 
+def _rocm_aiter_triton_gemm_a8w8_blockscale_impl(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
+
+    return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
+
+
+def _rocm_aiter_triton_gemm_a8w8_blockscale_fake(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    m = A.shape[0]
+    n = B.shape[0]
+    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
+    return Y
+
+
 def _rocm_aiter_gemm_a8w8_blockscale_impl(
     A: torch.Tensor,
     B: torch.Tensor,
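The fake implementation above is all that Dynamo/FakeTensor tracing sees in place of the Triton kernel: it promises an output of shape (M, N) in output_dtype on A's device, and nothing else. A quick illustration of that contract; the fp8 dtype and shapes here are placeholders, not the kernel's documented layout.

import torch

A = torch.empty(32, 256, dtype=torch.float8_e4m3fnuz)   # (M, K) activations
B = torch.empty(512, 256, dtype=torch.float8_e4m3fnuz)  # (N, K) weights

# What _rocm_aiter_triton_gemm_a8w8_blockscale_fake(A, B, As, Bs) returns:
Y = torch.empty(A.shape[0], B.shape[0], dtype=torch.float16, device=A.device)
print(Y.shape)  # torch.Size([32, 512]); no aiter import, no kernel launch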
@@ -964,6 +989,12 @@ class rocm_aiter_ops:
         dispatch_key=current_platform.dispatch_key,
     )
 
+    direct_register_custom_op(
+        op_name="rocm_aiter_triton_gemm_a8w8_blockscale",
+        op_func=_rocm_aiter_triton_gemm_a8w8_blockscale_impl,
+        fake_impl=_rocm_aiter_triton_gemm_a8w8_blockscale_fake,
+    )
+
     direct_register_custom_op(
         op_name="rocm_aiter_gemm_a8w8_blockscale",
         op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
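Once registered, the new op is addressable as torch.ops.vllm.rocm_aiter_triton_gemm_a8w8_blockscale, which is exactly how the wrapper in the next hunk calls it. A hedged usage sketch, assuming a ROCm build of vLLM with this patch and aiter installed, and leaving input shapes and dtypes to the kernel's actual contract:

import torch
import vllm  # assumption: importing vLLM (ROCm build) registers the custom ops


@torch.compile
def fp8_blockscale_mm(A, B, As, Bs):
    # One opaque custom-op node in the compiled graph; the fake implementation
    # supplies the output shape at trace time, the Triton kernel runs at runtime.
    return torch.ops.vllm.rocm_aiter_triton_gemm_a8w8_blockscale(
        A, B, As, Bs, torch.bfloat16
    )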
@@ -1102,6 +1133,19 @@ class rocm_aiter_ops:
     ) -> torch.Tensor:
         return torch.ops.vllm.rocm_aiter_gemm_a8w8(A, B, As, Bs, bias, output_dtype)
 
+    @staticmethod
+    def triton_gemm_a8w8_blockscale(
+        A: torch.Tensor,
+        B: torch.Tensor,
+        As: torch.Tensor,
+        Bs: torch.Tensor,
+        block_size: list[int],
+        output_dtype: torch.dtype = torch.float16,
+    ) -> torch.Tensor:
+        return torch.ops.vllm.rocm_aiter_triton_gemm_a8w8_blockscale(
+            A, B, As, Bs, output_dtype
+        )
+
     @staticmethod
     def gemm_a8w8_blockscale(
         A: torch.Tensor,
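For intuition about what the wrapped kernel computes: the fake implementation's (M, N) output from an (M, K) activation and an (N, K) weight implies a dequantize-then-matmul of the form sketched below. This is a simplified reference only: it assumes per-row activation and weight scales rather than the kernel's actual block layout, and it is not how aiter implements it. Note also that the wrapper accepts block_size but, as the hunk above shows, does not forward it to the op; the Triton kernel presumably derives the block layout from the scale shapes.

import torch


def reference_a8w8_rowwise(A, B, As, Bs, output_dtype=torch.float16):
    # Simplified stand-in: A (M, K) quantized activations with per-row scales
    # As (M, 1); B (N, K) quantized weights with per-row scales Bs (N, 1). The
    # real kernel uses finer-grained block scales; the output shape is the same.
    A_deq = A.to(torch.float32) * As
    B_deq = B.to(torch.float32) * Bs
    return (A_deq @ B_deq.t()).to(output_dtype)


M, K, N = 4, 8, 6
A = torch.randint(-8, 8, (M, K), dtype=torch.int8)
B = torch.randint(-8, 8, (N, K), dtype=torch.int8)
As = torch.rand(M, 1)
Bs = torch.rand(N, 1)
print(reference_a8w8_rowwise(A, B, As, Bs).shape)  # torch.Size([4, 6])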
@@ -1373,19 +1417,6 @@ class rocm_aiter_ops:
             config=config,
         )
 
-    @staticmethod
-    def triton_gemm_a8w8_blockscale(
-        A: torch.Tensor,
-        B: torch.Tensor,
-        As: torch.Tensor,
-        Bs: torch.Tensor,
-        block_size: list[int],
-        output_dtype: torch.dtype = torch.float16,
-    ) -> torch.Tensor:
-        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
-
-        return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
-
     @staticmethod
     def group_fp8_quant(
         input_2d: torch.Tensor,