[Bugfix] Fix Stream Sync for Shared Expert Overlap (#28430)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Robert Shaw authored on 2025-11-11 00:59:08 -05:00; committed by GitHub
parent bca74e32b7
commit e605e8e323
2 changed files with 15 additions and 33 deletions


@@ -3,6 +3,3 @@ accuracy_threshold: 0.45
 num_questions: 1319
 num_fewshot: 5
 max_model_len: 4096
-# Duo stream incompatabilbe with this model: https://github.com/vllm-project/vllm/issues/28220
-env:
-  VLLM_DISABLE_SHARED_EXPERTS_STREAM: "1"
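
Note: with the stream synchronization fixed, this eval config no longer needs to force VLLM_DISABLE_SHARED_EXPERTS_STREAM=1 as a workaround for issue #28220. As a rough sketch (the helper name and default handling below are assumptions for illustration, not vLLM's actual env plumbing), an engine would typically gate the extra CUDA stream on such a flag like this:

import os

import torch


# Sketch only: VLLM_DISABLE_SHARED_EXPERTS_STREAM is the real flag removed from
# this config, but maybe_shared_experts_stream() and its defaults are assumed.
def maybe_shared_experts_stream() -> "torch.cuda.Stream | None":
    """Create a side stream for shared experts unless explicitly disabled."""
    if os.environ.get("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0") == "1":
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.Stream()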


@@ -2456,28 +2456,6 @@ class FusedMoE(CustomOp):
         staged_hidden_states.copy_(hidden_states, non_blocking=True)
         staged_router_logits.copy_(router_logits, non_blocking=True)
-        # If there are shared experts but we are not using a modular kernel,
-        # the shared experts must be called here
-        if has_separate_shared_experts:
-            assert self.shared_experts is not None
-            if self.shared_experts_stream is not None:
-                # For chunked, we start the shared experts stream here
-                # (Note that no concurrency with the router/gate)
-                self.shared_experts_stream.wait_stream(current_stream())
-                with torch.cuda.stream(self.shared_experts_stream):
-                    # Note that staged_hidden_states clone() is necessary
-                    # here to avoid conflict with the main stream
-                    shared_output = self.shared_experts(
-                        staged_hidden_states.clone()
-                    )
-            else:
-                shared_output = self.shared_experts(staged_hidden_states)
-        else:
-            shared_output = None
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
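
The block removed above cloned staged_hidden_states inside the side-stream context, so the copy was made by the shared-experts stream while the main stream was free to keep running kernels that may mutate the staging buffer. A minimal sketch of that ordering hazard, using stand-in names rather than vLLM's own signatures:

import torch


# Illustrative only: `routed_kernel` stands in for any main-stream op that may
# mutate `buf` in place (e.g. a quantized routed-experts kernel).
def unsafe_clone_on_side_stream(
    buf: torch.Tensor,
    routed_kernel,
    side: torch.cuda.Stream,
) -> torch.Tensor:
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        snapshot = buf.clone()   # enqueued on `side`, not necessarily executed yet
    routed_kernel(buf)           # main stream does NOT wait for the clone,
                                 # so it may overwrite `buf` mid-copy
    return snapshot
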
@@ -2506,11 +2484,7 @@ class FusedMoE(CustomOp):
         if has_separate_shared_experts:
             assert not isinstance(final_hidden_states, tuple)
             assert self.shared_experts is not None
-            # Here we finish the shared experts stream
-            if self.shared_experts_stream is not None:
-                current_stream().wait_stream(self.shared_experts_stream)
+            shared_output = self.shared_experts(staged_hidden_states)
             final_hidden_states = (
                 shared_output,
                 final_hidden_states,
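
After this hunk, the chunked path runs the shared experts on the current stream, after the routed-experts matmul, so staged_hidden_states is never read from a second stream. A rough equivalent of the new control flow, with stand-in callables instead of vLLM's real signatures:

import torch


# `routed_forward` and `shared_forward` are assumed names used only to show the
# new, single-stream ordering of the chunked path.
def process_chunk_sketch(
    staged_hidden_states: torch.Tensor,
    routed_forward,                 # routed-experts matmul (quant_method.apply)
    shared_forward=None,            # shared experts, if the model has them
):
    final_hidden_states = routed_forward(staged_hidden_states)
    if shared_forward is not None:
        # Same stream, issued after the routed kernel: no cross-stream
        # synchronization or defensive clone() is needed here.
        shared_output = shared_forward(staged_hidden_states)
        return shared_output, final_hidden_states
    return final_hidden_states
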
@@ -2619,11 +2593,22 @@ class FusedMoE(CustomOp):
             assert self.shared_experts is not None
             if self.shared_experts_stream is not None:
+                # Clone BEFORE switching streams to avoid race condition
+                # where routed_expert kernel may mutate hidden_states.
+                hidden_states_clone = hidden_states.clone()
                 self.shared_experts_stream.wait_stream(current_stream())
                 # Run shared experts in parallel on a separate stream
                 with torch.cuda.stream(self.shared_experts_stream):
-                    # Note that hidden_states clone() is necessary here to avoid
-                    # conflict with the main stream
-                    shared_output = self.shared_experts(hidden_states.clone())
+                    shared_output = self.shared_experts(hidden_states_clone)
+                # Record that the clone will be used by shared_experts_stream
+                # to avoid gc issue from deallocation of hidden_states_clone
+                # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
+                # NOTE: we dont need shared_output.record_stream(current_stream())
+                # because we synch the streams before using shared_output.
+                hidden_states_clone.record_stream(self.shared_experts_stream)
             else:
                 shared_output = self.shared_experts(hidden_states)
         else:
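
For the non-chunked path the overlap is kept, but with the corrected ordering: clone on the current stream first, hand the clone to the side stream, and record_stream the clone so the caching allocator will not recycle its memory while the side stream is still reading it. A self-contained sketch of that pattern (stand-in callables, not FusedMoE itself):

import torch


# Illustrative pattern only; FusedMoE wires this through its own members
# (shared_experts, shared_experts_stream), which are not reproduced here.
def overlapped_shared_experts(
    hidden_states: torch.Tensor,
    routed_forward,                   # main-stream work that may mutate hidden_states
    shared_forward,                   # shared-experts forward
    side: torch.cuda.Stream,
):
    # 1) Snapshot on the CURRENT stream, before any side-stream work is enqueued,
    #    so a later routed kernel cannot mutate hidden_states under the copy.
    hidden_states_clone = hidden_states.clone()

    # 2) The side stream waits for everything enqueued so far, including the clone.
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        shared_output = shared_forward(hidden_states_clone)

    # 3) Mark the clone as in use by the side stream so the caching allocator
    #    does not hand its memory to a new main-stream allocation too early.
    hidden_states_clone.record_stream(side)

    # Main-stream work overlaps with the side stream here.
    final_hidden_states = routed_forward(hidden_states)

    # 4) Join before shared_output is consumed on the current stream (which is why
    #    shared_output itself does not need record_stream()).
    torch.cuda.current_stream().wait_stream(side)
    return shared_output, final_hidden_states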