[Bugfix] Fix GPT-OSS AR+NORM fusion (#28841)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
parent ef1f7030f0 · commit 6330f9477d
@@ -971,6 +971,7 @@ steps:
   - vllm/model_executor/layers/layernorm.py
   - vllm/model_executor/layers/activation.py
   - vllm/model_executor/layers/quantization/input_quant_fp8.py
+  - vllm/model_executor/layers/fused_moe/layer.py
   - tests/compile/test_fusion_attn.py
   - tests/compile/test_silu_mul_quant_fusion.py
   - tests/compile/distributed/test_fusion_all_reduce.py
@@ -111,6 +111,17 @@ if current_platform.is_cuda():
                 async_tp=96,  # MLP is MoE, half the fusions of dense
             ),
         ),
+        ModelBackendTestCase(
+            model_name="openai/gpt-oss-20b",
+            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
+            backend=AttentionBackendEnum.FLASHINFER,
+            matches=Matches(
+                attention_fusion=0,
+                allreduce_fusion=49,
+                sequence_parallel=49,
+                async_tp=48,
+            ),
+        ),
     ]

 elif current_platform.is_rocm():
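Note: a plausible accounting for the expected counts (an assumption on my part, not taken from the test harness): gpt-oss-20b has 24 transformer blocks, each with two allreduce+rmsnorm sites (attention output and MoE output), plus one final norm.

# Illustrative sketch only; the per-block site count is an assumption.
num_blocks = 24          # gpt-oss-20b transformer blocks
sites_per_block = 2      # attention-output norm + MoE-output norm
print(num_blocks * sites_per_block + 1)  # 49 -> allreduce_fusion / sequence_parallel
print(num_blocks * sites_per_block)      # 48 -> async_tp (no final-norm site)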
@@ -131,7 +131,7 @@ class SymmMemCommunicator:
             return None
         if out is None:
             out = torch.empty_like(inp)
-        self.buffer[: inp.numel()].copy_(inp.view(-1))
+        self.buffer[: inp.numel()].copy_(inp.reshape(-1))

         # Determine which algorithm to use
         use_multimem = False
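Note: the view(-1) → reshape(-1) change is needed because the all-reduce input can now be a non-contiguous slice (the MoE output is sliced back to the original hidden size before the collective; see the FusedMoE hunks below). A minimal standalone sketch of the failure mode:

import torch

# Slicing the last dimension produces a non-contiguous tensor, like the
# pre-all-reduce slice introduced in FusedMoE.reduce_output below:
x = torch.randn(4, 8)[..., :6]
assert not x.is_contiguous()

try:
    x.view(-1)  # view() requires compatible strides and raises here
except RuntimeError as err:
    print(f"view failed: {err}")

flat = x.reshape(-1)  # reshape() falls back to a copy and succeeds
assert flat.shape == (24,)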
@@ -1690,6 +1690,10 @@ class FusedMoE(CustomOp):
         )

         def reduce_output(states: torch.Tensor) -> torch.Tensor:
+            # Slice before all_reduce to enable possible fusion
+            if self.hidden_size != og_hidden_states:
+                states = states[..., :og_hidden_states]
+
             if (
                 not self.is_sequence_parallel
                 and not self.use_dp_chunking
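Note: this and the following hunks are in FusedMoE, i.e. vllm/model_executor/layers/fused_moe/layer.py, the file newly added to the pipeline dependencies above. Slicing before the all-reduce is numerically safe because an elementwise sum commutes with slicing the hidden dimension, and it leaves a bare all_reduce → rmsnorm sequence for the AR+NORM fusion pass to match, rather than all_reduce → slice → rmsnorm. A quick equivalence check, modeling the all-reduce as a plain sum over ranks:

import torch

# Simulate 4 TP ranks, each holding a padded (hidden=8) partial MoE output,
# with the model's original hidden size being 6:
og_hidden_states = 6
partials = [torch.randn(2, 8) for _ in range(4)]

reduce_then_slice = sum(partials)[..., :og_hidden_states]             # old order
slice_then_reduce = sum(p[..., :og_hidden_states] for p in partials)  # new order
assert torch.allclose(reduce_then_slice, slice_then_reduce)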
@@ -1712,11 +1716,12 @@ class FusedMoE(CustomOp):
             if self.zero_expert_num is not None and self.zero_expert_num > 0:
                 assert isinstance(fused_output, tuple)
                 fused_output, zero_expert_result = fused_output
-                return (reduce_output(fused_output) + zero_expert_result)[
-                    ..., :og_hidden_states
-                ]
+                return (
+                    reduce_output(fused_output)
+                    + zero_expert_result[..., :og_hidden_states]
+                )
             else:
-                return reduce_output(fused_output)[..., :og_hidden_states]
+                return reduce_output(fused_output)
         else:
             if current_platform.is_tpu():
                 # TODO: Once the OOM issue for the TPU backend is resolved, we
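Note: with the slice moved inside reduce_output, only zero_expert_result still needs an explicit slice at the call site; the rewrite is equivalent because slicing distributes over elementwise addition. A toy check:

import torch

og_hidden_states = 6
fused = torch.randn(2, 8)        # padded, already-reduced MoE output
zero_expert = torch.randn(2, 8)  # padded zero-expert contribution

old = (fused + zero_expert)[..., :og_hidden_states]
new = fused[..., :og_hidden_states] + zero_expert[..., :og_hidden_states]
assert torch.allclose(old, new)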
@@ -1729,8 +1734,8 @@
                 hidden_states, router_logits, self.layer_name
             )
             return (
-                reduce_output(shared_output)[..., :og_hidden_states],
-                reduce_output(fused_output)[..., :og_hidden_states],
+                reduce_output(shared_output),
+                reduce_output(fused_output),
             )

     def forward_cuda(