[Perf] Apply torch.compile for per_block_cast_to_fp8 (#24611)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Authored by Wentao Ye on 2025-09-22 21:42:45 -04:00; committed by yewentao256
parent dbb029cfe1
commit e6c22d2b2f

@@ -135,7 +135,7 @@ DEFAULT_BLOCK_SIZE = [128, 128]
 # Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38
 # TODO(wentao): optimize this function, using triton or cuda kernel
+@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def per_block_cast_to_fp8(
     x: torch.Tensor,
     block_size: list[int] = DEFAULT_BLOCK_SIZE,
@@ -187,4 +187,4 @@ __all__ = [
     "is_deep_gemm_e8m0_used",
     "is_deep_gemm_supported",
     "should_use_deepgemm_for_fp8_linear",
-]
+]
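
For context, below is a minimal sketch of the kind of per-block FP8 cast this decorator now compiles. It follows the DeepGEMM reference linked in the comment above, not the exact vLLM implementation; the name per_block_cast_to_fp8_sketch is hypothetical, and it uses the default torch.compile backend rather than current_platform.simple_compile_backend, purely for illustration.

    # Minimal sketch only -- illustrative, not the vLLM implementation.
    import torch

    DEFAULT_BLOCK_SIZE = [128, 128]
    FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn


    @torch.compile(dynamic=True)  # vLLM passes backend=current_platform.simple_compile_backend
    def per_block_cast_to_fp8_sketch(
        x: torch.Tensor,
        block_size: list[int] = DEFAULT_BLOCK_SIZE,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize a 2D tensor to fp8 with one scale per block_m x block_n tile."""
        block_m, block_n = block_size
        m, n = x.shape
        # Pad so both dimensions are multiples of the block size.
        pad_m = (block_m - m % block_m) % block_m
        pad_n = (block_n - n % block_n) % block_n
        x_pad = torch.nn.functional.pad(x, (0, pad_n, 0, pad_m))
        pm, pn = x_pad.shape
        # View as (tiles_m, block_m, tiles_n, block_n) so reductions act per tile.
        x_view = x_pad.view(pm // block_m, block_m, pn // block_n, block_n)
        # One amax (hence one scale) per tile; clamp avoids division by zero.
        amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(min=1e-4)
        scale = amax / FP8_MAX
        x_fp8 = (x_view / scale).to(torch.float8_e4m3fn)
        # Drop the padding and return per-tile scales of shape (tiles_m, tiles_n).
        return x_fp8.view(pm, pn)[:m, :n].contiguous(), scale.squeeze(1).squeeze(-1)

Since the function is a chain of elementwise pad/abs/amax/divide/cast ops, torch.compile can fuse them into far fewer kernel launches, which is the [Perf] gain this commit targets; the TODO above still leaves room for a dedicated Triton or CUDA kernel later.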