[Perf] Set split_k to 1 for triton_kernels (#30528)

Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
Xin Yang 2025-12-12 11:07:57 -08:00 committed by GitHub
parent cd7740ac5c
commit 1f19d8f899
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -57,12 +57,18 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
mx_axis=1, num_warps=num_warps
)
)
if current_platform.is_cuda() and current_platform.is_device_capability(100):
constraints = {
"is_persistent": True,
"epilogue_subtile": 1,
}
opt_flags.update_opt_flags_constraints(constraints)
if current_platform.is_cuda():
if current_platform.is_device_capability(90):
constraints = {
"split_k": 1,
}
opt_flags.update_opt_flags_constraints(constraints)
elif current_platform.is_device_capability(100):
constraints = {
"is_persistent": True,
"epilogue_subtile": 1,
}
opt_flags.update_opt_flags_constraints(constraints)
# transpose the tensor so that the quantization axis is on dim1
quant_tensor = quant_tensor.transpose(-2, -1)
scale = scale.transpose(-2, -1)