mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-12 06:37:03 +08:00
re-enable cudagraph+torch.compile
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
922165cba3
commit
e69879996f
@ -672,7 +672,7 @@ def pplx_moe(
|
||||
w1_scale_chunk = None
|
||||
w2_scale_chunk = None
|
||||
|
||||
if False and use_compile:
|
||||
if use_compile:
|
||||
_fused_experts = torch.compile(fused_experts,
|
||||
backend='inductor',
|
||||
fullgraph=True)
|
||||
@ -688,7 +688,7 @@ def pplx_moe(
|
||||
w2_scale=w2_scale_chunk,
|
||||
global_num_experts=num_experts)
|
||||
|
||||
if False and use_cudagraphs: #XXXXXXXXXXXX
|
||||
if use_cudagraphs:
|
||||
out.fill_(0)
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
|
||||
@ -635,6 +635,26 @@ def batched_moe_kernel_quantize_input(
|
||||
per_channel_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if (torch.compiler.is_compiling()
|
||||
or torch.cuda.is_current_stream_capturing()):
|
||||
# Note: this does a bunch of extra work because expert_num_tokens is ignored
|
||||
# but it does support torch.compile + cudagraphs.
|
||||
hidden_dim = A.size(-1)
|
||||
if block_shape is not None:
|
||||
block_shape = [block_shape[1], block_shape[0]]
|
||||
assert A_scale is None or A_scale.dim() == 2
|
||||
A_q, A_q_scale = moe_kernel_quantize_input(
|
||||
A.view(-1, hidden_dim),
|
||||
A_scale,
|
||||
qtype,
|
||||
per_channel_quant,
|
||||
block_shape)
|
||||
A_q = A_q.view(E, -1, hidden_dim)
|
||||
if A_q_scale is not None:
|
||||
A_q_scale = A_q_scale.view(E, -1, A_q_scale.size(-1))
|
||||
return A_q, A_q_scale
|
||||
|
||||
|
||||
if qtype is not None:
|
||||
assert block_shape is not None
|
||||
A_q = torch.empty_like(A, dtype=qtype)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user