[Bugfix] fused_experts_impl wrong compute type for float32 (#11921)
Signed-off-by: shaochangxu.scx <shaochangxu.scx@antgroup.com>
Co-authored-by: shaochangxu.scx <shaochangxu.scx@antgroup.com>
parent 2118d0565c
commit c32a7c7c0c
@@ -701,8 +701,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
 
-    compute_type = (tl.bfloat16
-                    if hidden_states.dtype == torch.bfloat16 else tl.float16)
+    if hidden_states.dtype == torch.bfloat16:
+        compute_type = tl.bfloat16
+    elif hidden_states.dtype == torch.float16:
+        compute_type = tl.float16
+    elif hidden_states.dtype == torch.float32:
+        compute_type = tl.float32
+    else:
+        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
 
     if inplace:
         out_hidden_states = hidden_states
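The old code picked the Triton compute type with a two-way conditional, so float32 hidden states silently fell back to tl.float16; the hunk above replaces it with an explicit dispatch that also covers float32 and rejects unsupported dtypes. A minimal sketch of the same mapping, written as a lookup table instead of an if/elif chain (the helper name select_compute_type and the _COMPUTE_TYPES dict are illustrative assumptions, not vLLM code):

import torch
import triton.language as tl

# Sketch only: table-driven equivalent of the dispatch added in this commit.
_COMPUTE_TYPES = {
    torch.bfloat16: tl.bfloat16,
    torch.float16: tl.float16,
    torch.float32: tl.float32,  # the case this fix adds; previously fell back to tl.float16
}

def select_compute_type(dtype: torch.dtype):
    # Map a torch dtype to the Triton compute type used by the fused MoE kernel,
    # raising for anything outside the supported set (mirrors the diff above).
    try:
        return _COMPUTE_TYPES[dtype]
    except KeyError:
        raise ValueError(f"Unsupported compute_type: {dtype}") from None

# Example: float32 inputs now keep a float32 compute type.
assert select_compute_type(torch.float32) == tl.float32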