remove FA changes

Signed-off-by: Sage Moore <sage@neuralmagic.com>

parent 462c6b0b50
commit 90330563c6
@@ -27,7 +27,7 @@ from vllm.logger import init_logger
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
-    make_local_attention_virtual_batches, slice_query_start_locs)
+    make_local_attention_virtual_batches)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable
 
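Note: dropping slice_query_start_locs from the import list matches the removal of build_slice below; the helper re-based the cumulative query offsets for a sub-range of requests. A minimal sketch of that behavior, assuming the cumulative layout CommonAttentionMetadata uses (illustrative, not the exact vLLM implementation):

import torch

def slice_query_start_locs(query_start_loc: torch.Tensor,
                           req_slice: slice) -> torch.Tensor:
    # query_start_loc is cumulative: [0, q0, q0 + q1, ...]. Take the
    # sub-range covering the requests in req_slice and re-base it so the
    # first request in the slice starts at token 0.
    sliced = query_start_loc[req_slice.start:req_slice.stop + 1]
    return sliced - sliced[0]

For example, slicing requests 1..2 out of query_start_loc = [0, 5, 9, 14] yields [5, 9, 14], re-based to [0, 4, 9].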
@@ -172,27 +172,28 @@ class FlashAttentionMetadataBuilder(
         # populated on first build() call.
         self.aot_sliding_window: Optional[tuple[int, int]] = None
 
-    def build_slice(
-        self,
-        req_slice: slice,
-        token_slice: slice,
-        max_query_len: int,
-        common_prefix_len: int,
-        common_attn_metadata: CommonAttentionMetadata,
+    def build(
+        self, common_prefix_len: int,
+        common_attn_metadata: CommonAttentionMetadata
     ) -> FlashAttentionMetadata:
-        num_reqs = req_slice.stop - req_slice.start
-        num_tokens = token_slice.stop - token_slice.start
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        max_query_len = common_attn_metadata.max_query_len
 
-        max_seq_len = int(self.runner.seq_lens_np[req_slice].max())
-        query_start_loc = slice_query_start_locs(
-            common_attn_metadata.query_start_loc, req_slice)
-        seq_lens = common_attn_metadata.seq_lens[req_slice]
+        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
+        query_start_loc = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens
 
         block_table = self.block_table
-        block_table_tensor = block_table.get_device_tensor()[req_slice]
+        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
 
-        block_table.slot_mapping[token_slice].copy_(
-            block_table.slot_mapping_cpu[token_slice], non_blocking=True)
-        slot_mapping = block_table.slot_mapping[token_slice]
+        block_table.slot_mapping[:num_actual_tokens].copy_(
+            block_table.slot_mapping_cpu[:num_actual_tokens],
+            non_blocking=True)
+        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
+        # mode.
+        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
+
+        slot_mapping = block_table.slot_mapping[:num_actual_tokens]
 
         if self.aot_sliding_window is None:
             self.aot_sliding_window = (-1, -1)
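Note on the restored -1 fill: under full CUDA graph capture, the graph replays over the whole padded slot_mapping buffer, so entries past num_actual_tokens would carry stale slots from an earlier step; filling them with -1 lets the cache-write kernel treat those lanes as no-ops. A toy illustration with made-up sizes:

import torch

max_num_tokens = 8           # padded buffer size baked into the CUDA graph
num_actual_tokens = 5        # live tokens in this step
slot_mapping = torch.full((max_num_tokens,), 42, dtype=torch.int64)  # stale
slot_mapping[:num_actual_tokens] = torch.tensor([3, 4, 5, 96, 97])
slot_mapping[num_actual_tokens:].fill_(-1)   # padded lanes become no-ops
assert slot_mapping.tolist() == [3, 4, 5, 96, 97, -1, -1, -1]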
@@ -234,8 +235,8 @@ class FlashAttentionMetadataBuilder(
             seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                 virt_block_table_tensor = make_local_attention_virtual_batches(
                     self.runner.attention_chunk_size,
-                    query_start_loc,
-                    seq_lens,
+                    self.runner.query_start_loc_np[:num_reqs + 1],
+                    self.runner.seq_lens_np[:num_reqs],
                     block_table_tensor,
                     self.block_size,
                 )
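For context, make_local_attention_virtual_batches splits each request into attention-chunk-sized "virtual batches" so no query attends across a chunk boundary. Rough arithmetic for one pure-prefill request (hypothetical numbers; the real helper also handles partially filled chunks from earlier steps and remaps the block table):

import numpy as np

attention_chunk_size = 4
seq_len = 10                           # prefill from scratch: q_len == k_len
# chunk boundaries at 4 and 8 give three virtual batches
virt_k_seqlens = np.array([4, 4, 2])   # keys visible to each virtual batch
virt_q_seqlens = np.array([4, 4, 2])   # queries falling inside each chunk
assert virt_k_seqlens.sum() == seq_len
assert virt_q_seqlens.sum() == seq_len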
@@ -265,20 +266,20 @@ class FlashAttentionMetadataBuilder(
         use_cascade = common_prefix_len > 0
 
         if use_cascade:
-            cu_prefix_query_lens = torch.tensor([0, num_tokens],
+            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                 dtype=torch.int32,
                                                 device=self.runner.device)
             prefix_kv_lens = torch.tensor([common_prefix_len],
                                           dtype=torch.int32,
                                           device=self.runner.device)
-            suffix_kv_lens = (self.runner.seq_lens_np[req_slice] -
+            suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
                               common_prefix_len)
             suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
                 self.runner.device)
             prefix_scheduler_metadata = schedule(
                 batch_size=1,
                 cu_query_lens=cu_prefix_query_lens,
-                max_query_len=num_tokens,
+                max_query_len=num_actual_tokens,
                 seqlens=prefix_kv_lens,
                 max_seq_len=common_prefix_len,
                 causal=False)
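The cascade path above splits attention in two: one shared-prefix pass scheduled as a single batch over all query tokens (causal=False), and a per-request suffix pass over what remains. In numbers, with hypothetical lengths:

import numpy as np

common_prefix_len = 100                 # tokens shared by every request
seq_lens = np.array([120, 135, 150])    # total KV length per request
suffix_kv_lens = seq_lens - common_prefix_len
# prefix pass: every query token attends the same 100 shared KV entries;
# suffix pass: each request attends only its own remaining KV entries.
assert suffix_kv_lens.tolist() == [20, 35, 50]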
@@ -302,7 +303,7 @@ class FlashAttentionMetadataBuilder(
                 causal=True)
 
         attn_metadata = FlashAttentionMetadata(
-            num_actual_tokens=num_tokens,
+            num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
@@ -320,28 +321,13 @@ class FlashAttentionMetadataBuilder(
         )
         return attn_metadata
 
-    def build(
-        self, common_prefix_len: int,
-        common_attn_metadata: CommonAttentionMetadata
-    ) -> FlashAttentionMetadata:
-        num_reqs = common_attn_metadata.num_reqs
-        num_actual_tokens = common_attn_metadata.num_actual_tokens
-        max_query_len = common_attn_metadata.max_query_len
-        return self.build_slice(
-            req_slice=slice(0, num_reqs),
-            token_slice=slice(0, num_actual_tokens),
-            max_query_len=max_query_len,
-            common_prefix_len=common_prefix_len,
-            common_attn_metadata=common_attn_metadata,
-        )
-
     def can_run_in_cudagraph(
             self, common_attn_metadata: CommonAttentionMetadata) -> bool:
         # Full CUDA Graph always supported (FA2 support checked separately)
         return True
 
     def use_cascade_attention(self, *args, **kwargs) -> bool:
-        return False #use_cascade_attention(*args, **kwargs)
+        return use_cascade_attention(*args, **kwargs)
 
 
 class FlashAttentionImpl(AttentionImpl):
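The removed build()/build_slice() pair was a thin delegation: building metadata for the full batch is just the special case req_slice=slice(0, num_reqs), token_slice=slice(0, num_actual_tokens). A stripped-down sketch of the pattern this commit reverts (illustrative class, not the vLLM builder):

class SliceBuilder:
    def build_slice(self, req_slice: slice, token_slice: slice) -> dict:
        # a slice-aware builder can serve any contiguous sub-batch
        return {"reqs": req_slice, "tokens": token_slice}

    def build(self, num_reqs: int, num_tokens: int) -> dict:
        # the full batch is the slice [0, n), i.e. x[slice(0, n)] == x[:n]
        return self.build_slice(slice(0, num_reqs), slice(0, num_tokens))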