[Bugfix] Fix precision corruption when shared_experts_stream=None (#28942)

Signed-off-by: zhyajie <yajizhan@amd.com>
Co-authored-by: zhyajie <yajizhan@amd.com>
This commit is contained in:
杰兮 2025-11-20 03:30:57 +08:00 committed by GitHub
parent fe69f331f8
commit 9d2d561257
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 6 deletions

View File

@ -371,8 +371,8 @@ class FusedMoE(CustomOp):
logger.info_once("Disabling MoE shared_experts cuda stream")
self.shared_experts_stream = None
else:
# TODO(rob): enable shared expert overlap with non-cuda.
# aux_stream() returns None on non-cuda platforms.
# TODO(rob): enable shared expert overlap with non-cuda-alike.
# aux_stream() returns None on non-cuda-alike platforms.
self.shared_experts_stream = aux_stream()
if self.shared_experts_stream is not None:
logger.info_once("Enabled separate cuda stream for MoE shared_experts")
@ -1865,6 +1865,11 @@ class FusedMoE(CustomOp):
hidden_states_combined, router_logits = get_ep_group().dispatch(
hidden_states, router_logits, self.is_sequence_parallel
)
# Run shared experts before matrix multiply.
# because matrix multiply maybe modify the hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream:
assert self.shared_experts is not None
shared_output = self.shared_experts(hidden_states)
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
@ -1908,8 +1913,6 @@ class FusedMoE(CustomOp):
# conflict with the main stream
shared_output = self.shared_experts(hidden_states_clone)
current_stream().wait_stream(self.shared_experts_stream)
else:
shared_output = self.shared_experts(hidden_states)
final_hidden_states = (
shared_output,

View File

@ -426,8 +426,7 @@ def aux_stream() -> torch.cuda.Stream | None:
from vllm.platforms import current_platform
# TODO: validate this works properly on ROCm platform.
if _aux_stream is None and current_platform.is_cuda():
if _aux_stream is None and current_platform.is_cuda_alike():
_aux_stream = torch.cuda.Stream()
return _aux_stream