mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-06 17:42:16 +08:00
re-enable cudagraph+torch.compile
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
922165cba3
commit
e69879996f
@ -672,7 +672,7 @@ def pplx_moe(
|
|||||||
w1_scale_chunk = None
|
w1_scale_chunk = None
|
||||||
w2_scale_chunk = None
|
w2_scale_chunk = None
|
||||||
|
|
||||||
if False and use_compile:
|
if use_compile:
|
||||||
_fused_experts = torch.compile(fused_experts,
|
_fused_experts = torch.compile(fused_experts,
|
||||||
backend='inductor',
|
backend='inductor',
|
||||||
fullgraph=True)
|
fullgraph=True)
|
||||||
@ -688,7 +688,7 @@ def pplx_moe(
|
|||||||
w2_scale=w2_scale_chunk,
|
w2_scale=w2_scale_chunk,
|
||||||
global_num_experts=num_experts)
|
global_num_experts=num_experts)
|
||||||
|
|
||||||
if False and use_cudagraphs: #XXXXXXXXXXXX
|
if use_cudagraphs:
|
||||||
out.fill_(0)
|
out.fill_(0)
|
||||||
stream = torch.cuda.Stream()
|
stream = torch.cuda.Stream()
|
||||||
graph = torch.cuda.CUDAGraph()
|
graph = torch.cuda.CUDAGraph()
|
||||||
|
|||||||
@ -635,6 +635,26 @@ def batched_moe_kernel_quantize_input(
|
|||||||
per_channel_quant: bool,
|
per_channel_quant: bool,
|
||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
if (torch.compiler.is_compiling()
|
||||||
|
or torch.cuda.is_current_stream_capturing()):
|
||||||
|
# Note: this path does extra work (it quantizes every row of A) because expert_num_tokens is ignored,
|
||||||
|
# but it does support torch.compile + cudagraphs.
|
||||||
|
hidden_dim = A.size(-1)
|
||||||
|
if block_shape is not None:
|
||||||
|
block_shape = [block_shape[1], block_shape[0]]
|
||||||
|
assert A_scale is None or A_scale.dim() == 2
|
||||||
|
A_q, A_q_scale = moe_kernel_quantize_input(
|
||||||
|
A.view(-1, hidden_dim),
|
||||||
|
A_scale,
|
||||||
|
qtype,
|
||||||
|
per_channel_quant,
|
||||||
|
block_shape)
|
||||||
|
A_q = A_q.view(E, -1, hidden_dim)
|
||||||
|
if A_q_scale is not None:
|
||||||
|
A_q_scale = A_q_scale.view(E, -1, A_q_scale.size(-1))
|
||||||
|
return A_q, A_q_scale
|
||||||
|
|
||||||
|
|
||||||
if qtype is not None:
|
if qtype is not None:
|
||||||
assert block_shape is not None
|
assert block_shape is not None
|
||||||
A_q = torch.empty_like(A, dtype=qtype)
|
A_q = torch.empty_like(A, dtype=qtype)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user