diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py
index f667864ca03f3..b00f0a24bfcb2 100644
--- a/tests/kernels/moe/test_pplx_moe.py
+++ b/tests/kernels/moe/test_pplx_moe.py
@@ -672,7 +672,7 @@ def pplx_moe(
         w1_scale_chunk = None
         w2_scale_chunk = None
 
-    if False and use_compile:
+    if use_compile:
         _fused_experts = torch.compile(fused_experts,
                                        backend='inductor',
                                        fullgraph=True)
@@ -688,7 +688,7 @@ def pplx_moe(
                              w2_scale=w2_scale_chunk,
                              global_num_experts=num_experts)
 
-    if False and use_cudagraphs: #XXXXXXXXXXXX
+    if use_cudagraphs:
         out.fill_(0)
         stream = torch.cuda.Stream()
         graph = torch.cuda.CUDAGraph()
diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
index 624e948a67395..a5a12732c5a83 100644
--- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
@@ -635,6 +635,26 @@ def batched_moe_kernel_quantize_input(
     per_channel_quant: bool,
     block_shape: Optional[list[int]] = None,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    if (torch.compiler.is_compiling()
+            or torch.cuda.is_current_stream_capturing()):
+        # Note: this does extra work because expert_num_tokens is ignored,
+        # but it does support torch.compile + cudagraphs.
+        hidden_dim = A.size(-1)
+        if block_shape is not None:
+            block_shape = [block_shape[1], block_shape[0]]
+        assert A_scale is None or A_scale.dim() == 2
+        A_q, A_q_scale = moe_kernel_quantize_input(
+            A.view(-1, hidden_dim),
+            A_scale,
+            qtype,
+            per_channel_quant,
+            block_shape)
+        A_q = A_q.view(E, -1, hidden_dim)
+        if A_q_scale is not None:
+            A_q_scale = A_q_scale.view(E, -1, A_q_scale.size(-1))
+        return A_q, A_q_scale
+
     if qtype is not None:
         assert block_shape is not None
         A_q = torch.empty_like(A, dtype=qtype)
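
For context on the second hunk: `torch.compiler.is_compiling()` is true while Dynamo traces a function, and `torch.cuda.is_current_stream_capturing()` is true while a CUDA graph is being captured. In either mode, branching on `expert_num_tokens` would introduce data-dependent control flow, so the new path quantizes the entire padded `(E, max_tokens, hidden_dim)` buffer as one flat 2-D tensor. Below is a minimal sketch of that reshape round-trip in isolation; `fake_quantize` is a hypothetical stand-in for `moe_kernel_quantize_input` (per-token int8 absmax here), not vLLM's actual quantizer.

```python
import torch


def fake_quantize(A2d: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical stand-in for moe_kernel_quantize_input: per-token int8
    # absmax quantization over a 2-D (num_tokens, hidden_dim) input.
    scale = A2d.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
    return (A2d / scale).round().clamp(-128, 127).to(torch.int8), scale


def flat_batched_quantize(
        A: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Mirrors the compile/capture-safe branch above: flatten the padded
    # (E, max_tokens, hidden_dim) buffer to 2-D, quantize every row
    # (padding rows included -- the "extra work" the patch comment notes),
    # then restore the batched layout. No data-dependent control flow on
    # expert_num_tokens, so this traces and graph-captures cleanly.
    E, hidden_dim = A.size(0), A.size(-1)
    A_q, A_q_scale = fake_quantize(A.view(-1, hidden_dim))
    return A_q.view(E, -1, hidden_dim), A_q_scale.view(E, -1, 1)


# In plain eager mode, both predicates the patch branches on are False
# (is_current_stream_capturing() additionally requires a CUDA context).
assert not torch.compiler.is_compiling()

E, max_tokens, hidden_dim = 4, 16, 128
A_q, A_q_scale = flat_batched_quantize(torch.randn(E, max_tokens, hidden_dim))
assert A_q.shape == (E, max_tokens, hidden_dim) and A_q.dtype == torch.int8
assert A_q_scale.shape == (E, max_tokens, 1)
```

The trade-off is deliberate: quantizing padding rows wastes some compute relative to honoring `expert_num_tokens`, but it keeps the operation shape-stable and free of host-side synchronization, which is what torch.compile and CUDA graph capture require.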