diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index a7e0a00350e6..85429b3a01f9 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -170,8 +170,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         return quant_config
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        # NOTE: hidden_states can have either 1D or 2D shape.
-        orig_shape = hidden_states.shape
+        assert hidden_states.dim(
+        ) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
+        is_input_1d = hidden_states.dim() == 1
         hidden_dim = hidden_states.shape[-1]
         hidden_states = hidden_states.view(-1, hidden_dim)
 
@@ -180,7 +181,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)
 
-        return final_hidden_states.view(orig_shape)
+        # return to 1d if input is 1d
+        return final_hidden_states.squeeze(0) if is_input_1d else \
+            final_hidden_states
 
 
 class Qwen3MoeAttention(nn.Module):
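
For reference, a minimal standalone sketch of the shape handling this diff
introduces: view(-1, hidden_dim) promotes a 1D input to (1, hidden_dim), and
squeeze(0) restores the 1D shape on the way out. fake_experts is a
hypothetical stand-in for self.experts, not vLLM's real MoE dispatch.

import torch


def fake_experts(hidden_states: torch.Tensor) -> torch.Tensor:
    # Placeholder for the routed expert FFNs; like the real layer, it
    # preserves the (num_tokens, hidden_dim) shape.
    return hidden_states * 2.0


def forward(hidden_states: torch.Tensor) -> torch.Tensor:
    assert hidden_states.dim() <= 2, "only 1D or 2D inputs are supported"
    is_input_1d = hidden_states.dim() == 1
    hidden_dim = hidden_states.shape[-1]
    # A 1D (hidden_dim,) input becomes (1, hidden_dim); 2D passes through.
    hidden_states = hidden_states.view(-1, hidden_dim)
    out = fake_experts(hidden_states)
    # Drop the singleton token dimension only when the caller passed 1D.
    return out.squeeze(0) if is_input_1d else out


print(forward(torch.randn(8)).shape)     # torch.Size([8])
print(forward(torch.randn(4, 8)).shape)  # torch.Size([4, 8])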