diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index e88e693a2dda..e444becd9867 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -971,6 +971,7 @@ steps:
   - vllm/model_executor/layers/layernorm.py
   - vllm/model_executor/layers/activation.py
   - vllm/model_executor/layers/quantization/input_quant_fp8.py
+  - vllm/model_executor/layers/fused_moe/layer.py
   - tests/compile/test_fusion_attn.py
   - tests/compile/test_silu_mul_quant_fusion.py
   - tests/compile/distributed/test_fusion_all_reduce.py
diff --git a/tests/compile/distributed/test_fusions_e2e.py b/tests/compile/distributed/test_fusions_e2e.py
index 661172e1965b..53c3f875d200 100644
--- a/tests/compile/distributed/test_fusions_e2e.py
+++ b/tests/compile/distributed/test_fusions_e2e.py
@@ -111,6 +111,17 @@ if current_platform.is_cuda():
                 async_tp=96,  # MLP is MoE, half the fusions of dense
             ),
         ),
+        ModelBackendTestCase(
+            model_name="openai/gpt-oss-20b",
+            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
+            backend=AttentionBackendEnum.FLASHINFER,
+            matches=Matches(
+                attention_fusion=0,
+                allreduce_fusion=49,
+                sequence_parallel=49,
+                async_tp=48,
+            ),
+        ),
     ]
 
 elif current_platform.is_rocm():
diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py
index eb1f173b1192..7a049b003cf7 100644
--- a/vllm/distributed/device_communicators/symm_mem.py
+++ b/vllm/distributed/device_communicators/symm_mem.py
@@ -131,7 +131,7 @@ class SymmMemCommunicator:
             return None
         if out is None:
             out = torch.empty_like(inp)
-        self.buffer[: inp.numel()].copy_(inp.view(-1))
+        self.buffer[: inp.numel()].copy_(inp.reshape(-1))
 
         # Determine which algorithm to use
         use_multimem = False
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 0ef3130b2633..bb30f1292a5f 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1690,6 +1690,10 @@ class FusedMoE(CustomOp):
             )
 
         def reduce_output(states: torch.Tensor) -> torch.Tensor:
+            # Slice before all_reduce to enable possible fusion
+            if self.hidden_size != og_hidden_states:
+                states = states[..., :og_hidden_states]
+
             if (
                 not self.is_sequence_parallel
                 and not self.use_dp_chunking
@@ -1712,11 +1716,12 @@ class FusedMoE(CustomOp):
             if self.zero_expert_num is not None and self.zero_expert_num > 0:
                 assert isinstance(fused_output, tuple)
                 fused_output, zero_expert_result = fused_output
-                return (reduce_output(fused_output) + zero_expert_result)[
-                    ..., :og_hidden_states
-                ]
+                return (
+                    reduce_output(fused_output)
+                    + zero_expert_result[..., :og_hidden_states]
+                )
             else:
-                return reduce_output(fused_output)[..., :og_hidden_states]
+                return reduce_output(fused_output)
         else:
             if current_platform.is_tpu():
                 # TODO: Once the OOM issue for the TPU backend is resolved, we
@@ -1729,8 +1734,8 @@ class FusedMoE(CustomOp):
                 hidden_states, router_logits, self.layer_name
             )
             return (
-                reduce_output(shared_output)[..., :og_hidden_states],
-                reduce_output(fused_output)[..., :og_hidden_states],
+                reduce_output(shared_output),
+                reduce_output(fused_output),
             )
 
     def forward_cuda(
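
Note on the `symm_mem.py` change: `Tensor.view(-1)` only succeeds when the flattened result can alias the input's existing memory layout, so it raises on non-contiguous inputs, whereas `Tensor.reshape(-1)` returns a view when possible and otherwise falls back to a copy. A minimal standalone PyTorch sketch of that distinction (illustrative only, not vLLM code):

```python
import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
nc = x.t()  # transpose -> non-contiguous layout

flat = nc.reshape(-1)  # ok: silently copies when a zero-copy view is impossible
try:
    nc.view(-1)  # raises: flattening a transposed tensor cannot be expressed as a view
except RuntimeError as err:
    print("view(-1) failed:", err)
```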