[Bugfix][ROCm][Dynamo][DS 3.1][FP8] fix unsupported hasattr call when Dynamo tracing for ROCm device (#31149)

Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
zejunchen-zejun 2025-12-24 13:32:19 +08:00 committed by GitHub
parent 3bb9561928
commit dabff12ed3

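Editor's note: the last hunk below removes a `triton_gemm_a8w8_blockscale` staticmethod that imported and called the aiter Triton kernel directly, which Dynamo would trace into and (per the commit title) trip over an unsupported `hasattr` call. The replacement registers the same computation as a torch custom op with a fake (shape-only) implementation through vLLM's `direct_register_custom_op` helper, so `torch.compile` sees an opaque op and never traces the kernel internals. A minimal, self-contained sketch of that pattern, assuming PyTorch >= 2.4 and using the stock `torch.library` API with hypothetical names (`mylib::my_gemm`) rather than the vLLM helper:

import torch


# Hypothetical op; the actual change registers "rocm_aiter_triton_gemm_a8w8_blockscale".
@torch.library.custom_op("mylib::my_gemm", mutates_args=())
def my_gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # The real implementation may lazily import a vendor kernel here;
    # Dynamo never traces this body because the registered op is opaque to it.
    return a @ b.t()


@my_gemm.register_fake
def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only "fake" implementation used during tracing.
    return a.new_empty(a.shape[0], b.shape[0])


compiled = torch.compile(lambda x, w: torch.ops.mylib.my_gemm(x, w))
print(compiled(torch.randn(4, 8), torch.randn(3, 8)).shape)  # torch.Size([4, 3])

vLLM's `direct_register_custom_op` wraps this same `torch.library` machinery; the fake function above plays the role of `_rocm_aiter_triton_gemm_a8w8_blockscale_fake` in the diff.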

@@ -380,6 +380,31 @@ def _rocm_aiter_gemm_a8w8_fake(
     return Y
 
 
+def _rocm_aiter_triton_gemm_a8w8_blockscale_impl(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
+
+    return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
+
+
+def _rocm_aiter_triton_gemm_a8w8_blockscale_fake(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    m = A.shape[0]
+    n = B.shape[0]
+    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
+    return Y
+
+
 def _rocm_aiter_gemm_a8w8_blockscale_impl(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -964,6 +989,12 @@ class rocm_aiter_ops:
             dispatch_key=current_platform.dispatch_key,
         )
 
+        direct_register_custom_op(
+            op_name="rocm_aiter_triton_gemm_a8w8_blockscale",
+            op_func=_rocm_aiter_triton_gemm_a8w8_blockscale_impl,
+            fake_impl=_rocm_aiter_triton_gemm_a8w8_blockscale_fake,
+        )
+
         direct_register_custom_op(
             op_name="rocm_aiter_gemm_a8w8_blockscale",
             op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
@@ -1102,6 +1133,19 @@ class rocm_aiter_ops:
     ) -> torch.Tensor:
         return torch.ops.vllm.rocm_aiter_gemm_a8w8(A, B, As, Bs, bias, output_dtype)
 
+    @staticmethod
+    def triton_gemm_a8w8_blockscale(
+        A: torch.Tensor,
+        B: torch.Tensor,
+        As: torch.Tensor,
+        Bs: torch.Tensor,
+        block_size: list[int],
+        output_dtype: torch.dtype = torch.float16,
+    ) -> torch.Tensor:
+        return torch.ops.vllm.rocm_aiter_triton_gemm_a8w8_blockscale(
+            A, B, As, Bs, output_dtype
+        )
+
     @staticmethod
     def gemm_a8w8_blockscale(
         A: torch.Tensor,
@@ -1373,19 +1417,6 @@ class rocm_aiter_ops:
             config=config,
         )
 
-    @staticmethod
-    def triton_gemm_a8w8_blockscale(
-        A: torch.Tensor,
-        B: torch.Tensor,
-        As: torch.Tensor,
-        Bs: torch.Tensor,
-        block_size: list[int],
-        output_dtype: torch.dtype = torch.float16,
-    ) -> torch.Tensor:
-        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
-
-        return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
-
     @staticmethod
     def group_fp8_quant(
         input_2d: torch.Tensor,
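
Trailing note: the fake implementation added in the first hunk only has to report the output's shape and dtype; that is all Dynamo's FakeTensor tracing needs, since the real aiter kernel is never executed at trace time. A small illustration of that contract, with made-up shapes, block-scale layouts, and an fp8 dtype that are not taken from the commit:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Shape-only stand-in mirroring _rocm_aiter_triton_gemm_a8w8_blockscale_fake:
# A is (M, K), B is (N, K), result is (M, N) in output_dtype.
def gemm_a8w8_blockscale_fake(A, B, As, Bs, output_dtype=torch.float16):
    return torch.empty(A.shape[0], B.shape[0], dtype=output_dtype, device=A.device)

with FakeTensorMode():
    A = torch.empty(32, 512, dtype=torch.float8_e4m3fnuz)    # activations (illustrative)
    B = torch.empty(1024, 512, dtype=torch.float8_e4m3fnuz)  # weights (illustrative)
    As = torch.empty(32, 4, dtype=torch.float32)              # per-block activation scales
    Bs = torch.empty(8, 4, dtype=torch.float32)               # per-block weight scales
    Y = gemm_a8w8_blockscale_fake(A, B, As, Bs)
    print(Y.shape, Y.dtype)  # torch.Size([32, 1024]) torch.float16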