[Compilation Bug] Fix Inductor Graph Output with Shape Issue (#24772)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye 2025-09-12 17:23:05 -04:00 committed by GitHub
parent bc636f21a6
commit 3beadc2f25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -170,8 +170,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
return quant_config
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
assert hidden_states.dim(
) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
is_input_1d = hidden_states.dim() == 1
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)
@ -180,7 +181,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
return final_hidden_states.view(orig_shape)
# return to 1d if input is 1d
return final_hidden_states.squeeze(0) if is_input_1d else \
final_hidden_states
class Qwen3MoeAttention(nn.Module):