[PERF] Qwen3-next MTP speedup (change bool mask indexing to index_select / index_copy to reduce d2h) (#26437)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
parent f6cdc9a02f · commit 785d8b6410
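Every hunk below applies the same idea. Boolean-mask indexing such as x[mask] has a data-dependent output shape, so PyTorch must run nonzero on the mask and copy the element count back to the host before it can allocate the result: an implicit device-to-host sync per use. index_select with a precomputed index tensor has a shape known up front and enqueues asynchronously. Building the index still costs one nonzero-style step, but it now happens once per step in the metadata builder instead of once per masked access per layer. A minimal standalone sketch of the two patterns (illustration only, not vLLM code):

import torch

x = torch.randn(1024, 64, device="cuda")
mask = torch.rand(1024, device="cuda") > 0.5

# Before: the output size depends on mask contents, so the gather is
# preceded by a device-to-host copy of the count (an implicit sync).
spec = x[mask]

# After: the index is computed once (e.g. in the metadata builder) and the
# output shape is simply index.numel(), so the gather enqueues async.
index = mask.nonzero(as_tuple=True)[0].to(torch.int32)
spec = x.index_select(0, index)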
@@ -45,7 +45,7 @@ def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]
     """
 
     cache_entries: tuple[tuple | None, dict | None, Any] = []
-    cache_size = 4
+    cache_size = 8
 
     @functools.wraps(fn)
     def wrapper(*args: Any, **kwargs: Any) -> Any:
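The cache bump above is a companion change; a plausible reading (my inference, the commit does not say) is that with MTP enabled the cached FLA helpers are called for both the spec and non-spec token groups each step, doubling the distinct argument patterns in flight, so a 4-entry cache starts evicting entries that are still hot. A minimal sketch of a cache in this style, assuming identity-based argument matching and FIFO eviction:

import functools
from collections import deque
from typing import Any, Callable

import torch


def tensor_cache(
    fn: Callable[..., torch.Tensor], cache_size: int = 8
) -> Callable[..., torch.Tensor]:
    # Sketch: match positional args by identity, evict the oldest entry when full.
    entries: deque = deque(maxlen=cache_size)

    @functools.wraps(fn)
    def wrapper(*args: Any) -> torch.Tensor:
        for cached_args, result in entries:
            if len(cached_args) == len(args) and all(
                a is b for a, b in zip(cached_args, args)
            ):
                return result
        result = fn(*args)
        entries.append((args, result))
        return result

    return wrapper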
@@ -423,7 +423,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
             (query, key),
         )
         value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim)
-        return query, key, value
+        return query.contiguous(), key.contiguous(), value.contiguous()
 
     def forward(
         self,
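Forcing .contiguous() on the returned query/key/value matters because they are views carved out of a fused projection; downstream kernels that expect dense layouts would otherwise pay for strided access or implicit copies on every use. A standalone illustration of how splitting a packed tensor yields non-contiguous views:

import torch

packed = torch.randn(16, 96)          # e.g. a fused [tokens, q|k|v] projection
q, k, v = packed.split([32, 32, 32], dim=-1)
print(q.is_contiguous())              # False: q strides over the packed buffer
q = q.contiguous()                    # one explicit dense copy, paid up front
print(q.is_contiguous())              # True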
@@ -455,7 +455,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         spec_query_start_loc = attn_metadata.spec_query_start_loc
         non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
         spec_sequence_masks = attn_metadata.spec_sequence_masks
-        spec_token_masks = attn_metadata.spec_token_masks
+        spec_token_indx = attn_metadata.spec_token_indx
+        non_spec_token_indx = attn_metadata.non_spec_token_indx
         spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor  # noqa: E501
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
         self_kv_cache = self.kv_cache[forward_context.virtual_engine]
@@ -463,8 +464,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens
-        if spec_token_masks is not None:
-            spec_token_masks = spec_token_masks[:num_actual_tokens]
 
         # 1. Set up dimensions for reshapes later
         projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states[:num_actual_tokens])
@@ -487,8 +486,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
                 mixed_qkv_spec = mixed_qkv
                 mixed_qkv_non_spec = None
             else:
-                mixed_qkv_spec = mixed_qkv[spec_token_masks]
-                mixed_qkv_non_spec = mixed_qkv[~spec_token_masks]
+                mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
+                mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx)
         else:
             mixed_qkv_spec = None
             mixed_qkv_non_spec = mixed_qkv
@@ -558,10 +557,10 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
                 g_non_spec = None
                 beta_non_spec = None
             else:
-                g_spec = g[:, spec_token_masks]
-                beta_spec = beta[:, spec_token_masks]
-                g_non_spec = g[:, ~spec_token_masks]
-                beta_non_spec = beta[:, ~spec_token_masks]
+                g_spec = g.index_select(1, spec_token_indx)
+                beta_spec = beta.index_select(1, spec_token_indx)
+                g_non_spec = g.index_select(1, non_spec_token_indx)
+                beta_non_spec = beta.index_select(1, non_spec_token_indx)
         else:
             g_spec = None
             beta_spec = None
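Same rewrite as the mixed_qkv hunk, but along dim=1: after the rearrange to a "1 l h d"-style layout, g and beta carry a leading singleton batch dimension, so the token axis to partition is dimension 1 rather than 0. Illustration:

import torch

g = torch.randn(1, 10, 4)                      # [1, tokens, heads]
spec_indx = torch.tensor([1, 4, 7], dtype=torch.int32)

g_spec = g.index_select(1, spec_indx)          # same rows as g[:, mask]
print(g_spec.shape)                            # torch.Size([1, 3, 4])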
@@ -638,8 +637,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
                 dtype=core_attn_out_non_spec.dtype,
                 device=core_attn_out_non_spec.device,
             )
-            core_attn_out[:, spec_token_masks] = core_attn_out_spec
-            core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
+            core_attn_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
+            core_attn_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
+
         elif spec_sequence_masks is not None:
             core_attn_out = core_attn_out_spec
         else:
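The write side is the mirror image: index_copy_ scatters each group's output back to its original token positions, replacing two boolean-mask assignments that would each have re-run nonzero and synced. A round-trip sketch (int64 index here; the builder below allocates int32, which the commit relies on these ops accepting as well):

import torch

out = torch.zeros(1, 8, 4)
spec_indx = torch.tensor([0, 3, 5])    # positions on the token axis
spec_out = torch.ones(1, 3, 4)

# Writes spec_out[:, i] to out[:, spec_indx[i]]; together with the non-spec
# copy, the two groups tile the token axis exactly once each.
out.index_copy_(1, spec_indx, spec_out)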
@@ -47,9 +47,9 @@ class GDNAttentionMetadata:
         None  # shape: [batch - num_spec_decodes,]
     )
     spec_sequence_masks: torch.Tensor | None = None  # shape: [batch,]
-    spec_token_masks: torch.Tensor | None = (
-        None  # shape: [num_prefill_tokens + num_decode_tokens,]
-    )
+    spec_token_indx: torch.Tensor | None = None
+    non_spec_token_indx: torch.Tensor | None = None
+
     num_accepted_tokens: torch.Tensor | None = None  # shape: [batch,]
 
     # The following attributes are for triton implementation of causal_conv1d
@@ -105,9 +105,14 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
                 dtype=torch.bool,
                 device=device,
             )
-            self.spec_token_masks = torch.empty(
+            self.spec_token_indx = torch.empty(
                 (self.decode_cudagraph_max_bs * (self.num_spec + 1),),
-                dtype=torch.bool,
+                dtype=torch.int32,
                 device=device,
             )
+            self.non_spec_token_indx = torch.empty(
+                (self.decode_cudagraph_max_bs * (self.num_spec + 1),),
+                dtype=torch.int32,
+                device=device,
+            )
             self.spec_query_start_loc = torch.empty(
@@ -166,7 +171,8 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
                 split_decodes_and_prefills(m, decode_threshold=1)
             )
             num_spec_decode_tokens = 0
-            spec_token_masks = None
+            spec_token_indx = None
+            non_spec_token_indx = None
             spec_state_indices_tensor = None
             non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
             spec_query_start_loc = None
@@ -180,18 +186,23 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
             num_prefills = non_spec_query_lens.size(0) - num_decodes
             num_decode_tokens = num_decodes
             num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens
+            num_spec_decode_tokens = (
+                query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
+            )
 
             if num_prefills == 0 and num_decodes == 0:
-                spec_token_masks = torch.ones(
-                    (
-                        min(
-                            num_spec_decodes * (self.num_spec + 1),
-                            query_start_loc[-1].item(),
-                        )
-                    ),
-                    dtype=torch.bool,
+                spec_token_size = min(
+                    num_spec_decodes * (self.num_spec + 1),
+                    query_start_loc[-1].item(),
+                )
+                spec_token_indx = torch.arange(
+                    spec_token_size,
+                    dtype=torch.int32,
                     device=query_start_loc.device,
                 )
+                non_spec_token_indx = torch.empty(
+                    0, dtype=torch.int32, device=query_start_loc.device
+                )
                 spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
                 non_spec_state_indices_tensor = None
                 spec_query_start_loc = query_start_loc
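In the all-spec-decode branch there is nothing to partition, so the old all-ones boolean mask collapses to a plain arange: every live token is speculative and the non-spec index is empty. The min(...) clamps the index length to the actual token count when fewer tokens are present than the num_spec_decodes * (num_spec + 1) maximum. Equivalent standalone sketch:

import torch

num_spec_decodes, num_spec, num_actual_tokens = 4, 2, 10
spec_token_size = min(num_spec_decodes * (num_spec + 1), num_actual_tokens)
spec_token_indx = torch.arange(spec_token_size, dtype=torch.int32)
non_spec_token_indx = torch.empty(0, dtype=torch.int32)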
@@ -200,6 +211,11 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
                 spec_token_masks = torch.repeat_interleave(
                     spec_sequence_masks, query_lens
                 )
+                index = torch.argsort(spec_token_masks)
+                num_non_spec_tokens = num_prefill_tokens + num_decode_tokens
+                non_spec_token_indx = index[:num_non_spec_tokens]
+                spec_token_indx = index[num_non_spec_tokens:]
+
                 spec_state_indices_tensor = m.block_table_tensor[
                     spec_sequence_masks, : self.num_spec + 1
                 ]
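For a mixed batch, one argsort of the per-token boolean mask yields both partitions at once: False sorts before True, so the first num_prefill_tokens + num_decode_tokens positions index the non-spec tokens and the remainder the spec tokens. A standalone sketch (with an explicit stable sort so each group keeps its original token order):

import torch

spec_token_masks = torch.tensor([False, True, True, False, True])
index = torch.argsort(spec_token_masks.to(torch.uint8), stable=True)

num_non_spec = int((~spec_token_masks).sum())
non_spec_token_indx = index[:num_non_spec]   # tensor([0, 3])
spec_token_indx = index[num_non_spec:]       # tensor([1, 2, 4])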
@@ -226,9 +242,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
                 out=non_spec_query_start_loc[1:],
             )
 
-            num_spec_decode_tokens = (
-                query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
-            )
             assert num_accepted_tokens is not None
             num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
 
@@ -274,12 +287,18 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
             spec_sequence_masks = self.spec_sequence_masks[:batch_size]
             spec_sequence_masks[num_spec_decodes:].fill_(False)
 
-            assert spec_token_masks is not None
-            self.spec_token_masks[: spec_token_masks.size(0)].copy_(
-                spec_token_masks, non_blocking=True
+            assert non_spec_token_indx is not None and spec_token_indx is not None
+            self.non_spec_token_indx[: non_spec_token_indx.size(0)].copy_(
+                non_spec_token_indx, non_blocking=True
             )
-            spec_token_masks = self.spec_token_masks[:num_actual_tokens]
-            spec_token_masks[spec_token_masks.size(0) :].fill_(False)
+            non_spec_token_indx = self.non_spec_token_indx[
+                : non_spec_token_indx.size(0)
+            ]
+
+            self.spec_token_indx[: spec_token_indx.size(0)].copy_(
+                spec_token_indx, non_blocking=True
+            )
+            spec_token_indx = self.spec_token_indx[: spec_token_indx.size(0)]
 
             self.spec_query_start_loc[: num_spec_decodes + 1].copy_(
                 spec_query_start_loc, non_blocking=True
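The copy_ sequence above follows the usual full-CUDA-graph pattern: a replayed graph reads fixed device addresses, so the freshly built indices are staged into preallocated max-size buffers (non_blocking=True keeps the copy on the stream) and the metadata then holds views of the valid prefix of those persistent buffers rather than the temporaries. A pattern sketch, assuming a CUDA device and with max_tokens standing in for decode_cudagraph_max_bs * (num_spec + 1):

import torch

max_tokens = 4096
buf = torch.empty(max_tokens, dtype=torch.int32, device="cuda")

def stage(indx: torch.Tensor) -> torch.Tensor:
    # Copy into the graph-stable buffer and return a view of the valid
    # prefix; replayed graphs keep reading the same device address.
    buf[: indx.size(0)].copy_(indx, non_blocking=True)
    return buf[: indx.size(0)]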
@@ -332,7 +351,8 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
             spec_state_indices_tensor=spec_state_indices_tensor,
             non_spec_state_indices_tensor=non_spec_state_indices_tensor,
             spec_sequence_masks=spec_sequence_masks,
-            spec_token_masks=spec_token_masks,
+            spec_token_indx=spec_token_indx,
+            non_spec_token_indx=non_spec_token_indx,
             num_accepted_tokens=num_accepted_tokens,
             nums_dict=nums_dict,
             batch_ptr=batch_ptr,