mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-02 10:04:28 +08:00
[Bugfix][ROCm][Dynamo][DS 3.1][FP8] fix unsupported hasattr call when Dynamo tracing for ROCm device (#31149)
Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
This commit is contained in:
parent
3bb9561928
commit
dabff12ed3
@ -380,6 +380,31 @@ def _rocm_aiter_gemm_a8w8_fake(
|
|||||||
return Y
|
return Y
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_triton_gemm_a8w8_blockscale_impl(
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
output_dtype: torch.dtype = torch.float16,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
|
||||||
|
|
||||||
|
return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_triton_gemm_a8w8_blockscale_fake(
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
output_dtype: torch.dtype = torch.float16,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
m = A.shape[0]
|
||||||
|
n = B.shape[0]
|
||||||
|
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
|
||||||
|
return Y
|
||||||
|
|
||||||
|
|
||||||
def _rocm_aiter_gemm_a8w8_blockscale_impl(
|
def _rocm_aiter_gemm_a8w8_blockscale_impl(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
B: torch.Tensor,
|
B: torch.Tensor,
|
||||||
@ -964,6 +989,12 @@ class rocm_aiter_ops:
|
|||||||
dispatch_key=current_platform.dispatch_key,
|
dispatch_key=current_platform.dispatch_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_aiter_triton_gemm_a8w8_blockscale",
|
||||||
|
op_func=_rocm_aiter_triton_gemm_a8w8_blockscale_impl,
|
||||||
|
fake_impl=_rocm_aiter_triton_gemm_a8w8_blockscale_fake,
|
||||||
|
)
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_aiter_gemm_a8w8_blockscale",
|
op_name="rocm_aiter_gemm_a8w8_blockscale",
|
||||||
op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
|
op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
|
||||||
@ -1102,6 +1133,19 @@ class rocm_aiter_ops:
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return torch.ops.vllm.rocm_aiter_gemm_a8w8(A, B, As, Bs, bias, output_dtype)
|
return torch.ops.vllm.rocm_aiter_gemm_a8w8(A, B, As, Bs, bias, output_dtype)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def triton_gemm_a8w8_blockscale(
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
block_size: list[int],
|
||||||
|
output_dtype: torch.dtype = torch.float16,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.ops.vllm.rocm_aiter_triton_gemm_a8w8_blockscale(
|
||||||
|
A, B, As, Bs, output_dtype
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def gemm_a8w8_blockscale(
|
def gemm_a8w8_blockscale(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
@ -1373,19 +1417,6 @@ class rocm_aiter_ops:
|
|||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def triton_gemm_a8w8_blockscale(
|
|
||||||
A: torch.Tensor,
|
|
||||||
B: torch.Tensor,
|
|
||||||
As: torch.Tensor,
|
|
||||||
Bs: torch.Tensor,
|
|
||||||
block_size: list[int],
|
|
||||||
output_dtype: torch.dtype = torch.float16,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
|
|
||||||
|
|
||||||
return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def group_fp8_quant(
|
def group_fp8_quant(
|
||||||
input_2d: torch.Tensor,
|
input_2d: torch.Tensor,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user