[Bugfix] fused_experts_impl wrong compute type for float32 (#11921)

Signed-off-by: shaochangxu.scx <shaochangxu.scx@antgroup.com>
Co-authored-by: shaochangxu.scx <shaochangxu.scx@antgroup.com>
shaochangxu 2025-01-11 13:49:39 +08:00 committed by GitHub
parent 2118d0565c
commit c32a7c7c0c


@@ -701,8 +701,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                   device=hidden_states.device,
                                   dtype=hidden_states.dtype)
-    compute_type = (tl.bfloat16
-                    if hidden_states.dtype == torch.bfloat16 else tl.float16)
+    if hidden_states.dtype == torch.bfloat16:
+        compute_type = tl.bfloat16
+    elif hidden_states.dtype == torch.float16:
+        compute_type = tl.float16
+    elif hidden_states.dtype == torch.float32:
+        compute_type = tl.float32
+    else:
+        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
     if inplace:
         out_hidden_states = hidden_states
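
Context on the bug: the original two-way conditional could only yield tl.bfloat16 or tl.float16, so float32 hidden states were silently run through the fused MoE kernels with a float16 compute type, losing precision. Below is a minimal sketch of the corrected dtype mapping, factored into a standalone helper for illustration; the helper name triton_compute_type is hypothetical and not part of this commit.

    import torch
    import triton.language as tl

    # Hypothetical helper mirroring the fixed logic in fused_experts_impl;
    # the name and factoring are illustrative, not from the vLLM source.
    def triton_compute_type(dtype: torch.dtype):
        """Map a torch dtype to the matching Triton compute type.

        Before this fix, float32 inputs fell through to tl.float16,
        so float32 models accumulated in half precision.
        """
        if dtype == torch.bfloat16:
            return tl.bfloat16
        elif dtype == torch.float16:
            return tl.float16
        elif dtype == torch.float32:
            return tl.float32
        raise ValueError(f"Unsupported compute_type: {dtype}")

Raising on an unrecognized dtype, rather than defaulting to tl.float16, makes any future dtype support an explicit decision instead of a silent downcast.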