Re-enable cudagraph + torch.compile

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
Bill Nell 2025-05-30 00:12:54 +00:00
parent 922165cba3
commit e69879996f
2 changed files with 22 additions and 2 deletions

View File

@ -672,7 +672,7 @@ def pplx_moe(
w1_scale_chunk = None
w2_scale_chunk = None
if False and use_compile:
if use_compile:
_fused_experts = torch.compile(fused_experts,
backend='inductor',
fullgraph=True)
@ -688,7 +688,7 @@ def pplx_moe(
w2_scale=w2_scale_chunk,
global_num_experts=num_experts)
if False and use_cudagraphs: #XXXXXXXXXXXX
if use_cudagraphs:
out.fill_(0)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()

View File

@ -635,6 +635,26 @@ def batched_moe_kernel_quantize_input(
per_channel_quant: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if (torch.compiler.is_compiling()
or torch.cuda.is_current_stream_capturing()):
# Note: this does a bunch of extra work because expert_num_tokens is ignored
# but it does support torch.compile + cudagraphs.
hidden_dim = A.size(-1)
if block_shape is not None:
block_shape = [block_shape[1], block_shape[0]]
assert A_scale is None or A_scale.dim() == 2
A_q, A_q_scale = moe_kernel_quantize_input(
A.view(-1, hidden_dim),
A_scale,
qtype,
per_channel_quant,
block_shape)
A_q = A_q.view(E, -1, hidden_dim)
if A_q_scale is not None:
A_q_scale = A_q_scale.view(E, -1, A_q_scale.size(-1))
return A_q, A_q_scale
if qtype is not None:
assert block_shape is not None
A_q = torch.empty_like(A, dtype=qtype)