mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 17:16:20 +08:00
[backends][short_conv] CUDA graph piecewise edits (#24215)
Signed-off-by: Paul Pak <paulpak58@gmail.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
7faf51f1cc
commit
c6f384dafd
@ -115,7 +115,7 @@ class ShortConv(MambaBase, CustomOp):
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
has_initial_states_p = attn_metadata.has_initial_states
|
||||
has_initial_states_p = attn_metadata.has_initial_states_p
|
||||
|
||||
BCx, _ = self.in_proj(hidden_states)
|
||||
|
||||
|
||||
@ -6,12 +6,12 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
from vllm.v1.attention.backends.mamba_attn import (
|
||||
BaseMambaAttentionMetadataBuilder)
|
||||
from vllm.v1.attention.backends.utils import (PAD_SLOT_ID,
|
||||
CommonAttentionMetadata,
|
||||
compute_causal_conv1d_metadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
|
||||
class ShortConvAttentionBackend(AttentionBackend):
|
||||
@ -29,8 +29,8 @@ class ShortConvAttentionMetadata:
|
||||
num_decode_tokens: int
|
||||
|
||||
query_start_loc: torch.Tensor
|
||||
has_initial_states: torch.Tensor
|
||||
state_indices_tensor: torch.Tensor # shape: [batch,]
|
||||
state_indices_tensor: torch.Tensor
|
||||
has_initial_states_p: Optional[torch.Tensor]
|
||||
|
||||
# For causal_conv1d
|
||||
nums_dict: Optional[dict] = None
|
||||
@ -39,14 +39,7 @@ class ShortConvAttentionMetadata:
|
||||
|
||||
|
||||
class ShortConvAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[ShortConvAttentionMetadata]):
|
||||
|
||||
reorder_batch_threshold: int = 1
|
||||
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||
vllm_config: VllmConfig, device: torch.device):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
assert isinstance(kv_cache_spec, MambaSpec)
|
||||
BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata]):
|
||||
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
@ -54,7 +47,6 @@ class ShortConvAttentionMetadataBuilder(
|
||||
fast_build: bool = False) -> ShortConvAttentionMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||
|
||||
# for causal_conv1d
|
||||
@ -64,13 +56,13 @@ class ShortConvAttentionMetadataBuilder(
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata,
|
||||
decode_threshold=self.reorder_batch_threshold))
|
||||
has_initial_states = None
|
||||
|
||||
has_initial_states_p = None
|
||||
if num_prefills > 0:
|
||||
#[batch,]
|
||||
has_initial_states_cpu = (
|
||||
common_attn_metadata.
|
||||
num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
|
||||
has_initial_states = has_initial_states_cpu.to(
|
||||
has_initial_states_p = has_initial_states_cpu.to(
|
||||
query_start_loc.device)
|
||||
|
||||
query_start_loc_p = common_attn_metadata.query_start_loc[
|
||||
@ -79,14 +71,22 @@ class ShortConvAttentionMetadataBuilder(
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = \
|
||||
compute_causal_conv1d_metadata(query_start_loc_p)
|
||||
|
||||
elif (num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs
|
||||
and self.compilation_config.full_cuda_graph):
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
|
||||
self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor,
|
||||
non_blocking=True)
|
||||
state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
|
||||
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
|
||||
|
||||
attn_metadata = ShortConvAttentionMetadata(
|
||||
query_start_loc=query_start_loc,
|
||||
state_indices_tensor=state_indices_tensor,
|
||||
has_initial_states_p=has_initial_states_p,
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
query_start_loc=query_start_loc,
|
||||
has_initial_states=has_initial_states,
|
||||
state_indices_tensor=state_indices_tensor,
|
||||
nums_dict=nums_dict,
|
||||
batch_ptr=batch_ptr,
|
||||
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user