mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 07:46:25 +08:00
[Compilation Bug] Fix Inductor Graph Output with Shape Issue (#24772)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
bc636f21a6
commit
3beadc2f25
@ -170,8 +170,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
return quant_config
|
return quant_config
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
# NOTE: hidden_states can have either 1D or 2D shape.
|
assert hidden_states.dim(
|
||||||
orig_shape = hidden_states.shape
|
) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
|
||||||
|
is_input_1d = hidden_states.dim() == 1
|
||||||
hidden_dim = hidden_states.shape[-1]
|
hidden_dim = hidden_states.shape[-1]
|
||||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||||
|
|
||||||
@ -180,7 +181,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
final_hidden_states = self.experts(hidden_states=hidden_states,
|
final_hidden_states = self.experts(hidden_states=hidden_states,
|
||||||
router_logits=router_logits)
|
router_logits=router_logits)
|
||||||
|
|
||||||
return final_hidden_states.view(orig_shape)
|
# return to 1d if input is 1d
|
||||||
|
return final_hidden_states.squeeze(0) if is_input_1d else \
|
||||||
|
final_hidden_states
|
||||||
|
|
||||||
|
|
||||||
class Qwen3MoeAttention(nn.Module):
|
class Qwen3MoeAttention(nn.Module):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user