remove FA changes

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Author: Sage Moore
Date: 2025-07-08 19:06:17 +00:00
Parent: 462c6b0b50
Commit: 90330563c6


@@ -27,7 +27,7 @@ from vllm.logger import init_logger
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
-    make_local_attention_virtual_batches, slice_query_start_locs)
+    make_local_attention_virtual_batches)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable
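
Note on this hunk: slice_query_start_locs was only consumed by the build_slice() path removed below, so the import goes away with it. For context, a minimal sketch of what a helper with this name presumably does (assumed semantics, not taken from this diff): slice the cumulative query-start offsets for a contiguous run of requests and re-base them to zero.

    import torch

    def slice_query_start_locs_sketch(query_start_loc: torch.Tensor,
                                      req_slice: slice) -> torch.Tensor:
        # query_start_loc holds num_reqs + 1 cumulative offsets; keep the
        # boundaries covering req_slice and shift so the slice starts at 0.
        sliced = query_start_loc[req_slice.start:req_slice.stop + 1]
        return sliced - sliced[0]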
@@ -172,27 +172,28 @@ class FlashAttentionMetadataBuilder(
         # populated on first build() call.
         self.aot_sliding_window: Optional[tuple[int, int]] = None

-    def build_slice(
-        self,
-        req_slice: slice,
-        token_slice: slice,
-        max_query_len: int,
-        common_prefix_len: int,
-        common_attn_metadata: CommonAttentionMetadata,
+    def build(
+        self, common_prefix_len: int,
+        common_attn_metadata: CommonAttentionMetadata
     ) -> FlashAttentionMetadata:
-        num_reqs = req_slice.stop - req_slice.start
-        num_tokens = token_slice.stop - token_slice.start
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        max_query_len = common_attn_metadata.max_query_len

-        max_seq_len = int(self.runner.seq_lens_np[req_slice].max())
-        query_start_loc = slice_query_start_locs(
-            common_attn_metadata.query_start_loc, req_slice)
-        seq_lens = common_attn_metadata.seq_lens[req_slice]
+        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
+        query_start_loc = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens

         block_table = self.block_table
-        block_table_tensor = block_table.get_device_tensor()[req_slice]
-        block_table.slot_mapping[token_slice].copy_(
-            block_table.slot_mapping_cpu[token_slice], non_blocking=True)
-        slot_mapping = block_table.slot_mapping[token_slice]
+        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
+        block_table.slot_mapping[:num_actual_tokens].copy_(
+            block_table.slot_mapping_cpu[:num_actual_tokens],
+            non_blocking=True)
+        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
+        # mode.
+        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
+        slot_mapping = block_table.slot_mapping[:num_actual_tokens]

         if self.aot_sliding_window is None:
             self.aot_sliding_window = (-1, -1)
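
Note on the fill_(-1): under full CUDA graph capture the kernel replays over the padded, fixed-size slot_mapping buffer, so the tail beyond num_actual_tokens must hold a sentinel that the cache-write kernel treats as a no-op. A toy sketch with hypothetical sizes:

    import torch

    max_num_tokens = 8      # padded size the CUDA graph was captured with
    num_actual_tokens = 5   # tokens actually scheduled this step

    slot_mapping = torch.empty(max_num_tokens, dtype=torch.int64)
    slot_mapping[:num_actual_tokens] = torch.tensor([3, 7, 12, 13, 20])
    # -1 marks padding; reshape_and_cache is expected to skip negative
    # slots instead of scattering stale values into the KV cache.
    slot_mapping[num_actual_tokens:].fill_(-1)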
@@ -234,8 +235,8 @@ class FlashAttentionMetadataBuilder(
             seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                 virt_block_table_tensor = make_local_attention_virtual_batches(
                     self.runner.attention_chunk_size,
-                    query_start_loc,
-                    seq_lens,
+                    self.runner.query_start_loc_np[:num_reqs + 1],
+                    self.runner.seq_lens_np[:num_reqs],
                     block_table_tensor,
                     self.block_size,
                 )
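
Note on this hunk: the virtual-batch bookkeeping runs on the CPU, so the call now takes the runner's persistent numpy views rather than the device tensors built above. For intuition, a worked example of the chunked split (hypothetical numbers; the real helper also remaps the block table):

    import numpy as np

    attention_chunk_size = 4
    seq_lens_np = np.array([10])   # one prefill request, 10 tokens
    # Tokens 0-3, 4-7, and 8-9 land in separate local-attention chunks,
    # so the request is split into three virtual batches whose query and
    # key lengths never cross a chunk boundary.
    virt_q_seqlens = np.array([4, 4, 2])
    virt_k_seqlens = np.array([4, 4, 2])
    assert virt_q_seqlens.sum() == seq_lens_np[0]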
@@ -265,20 +266,20 @@ class FlashAttentionMetadataBuilder(
         use_cascade = common_prefix_len > 0

         if use_cascade:
-            cu_prefix_query_lens = torch.tensor([0, num_tokens],
+            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                 dtype=torch.int32,
                                                 device=self.runner.device)
             prefix_kv_lens = torch.tensor([common_prefix_len],
                                           dtype=torch.int32,
                                           device=self.runner.device)
-            suffix_kv_lens = (self.runner.seq_lens_np[req_slice] -
+            suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
                               common_prefix_len)
             suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
                 self.runner.device)
             prefix_scheduler_metadata = schedule(
                 batch_size=1,
                 cu_query_lens=cu_prefix_query_lens,
-                max_query_len=num_tokens,
+                max_query_len=num_actual_tokens,
                 seqlens=prefix_kv_lens,
                 max_seq_len=common_prefix_len,
                 causal=False)
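
Note on the cascade path: all scheduled query tokens attend the shared prefix as one flat batch, then each request attends only its own suffix. A small numeric sketch (hypothetical batch):

    import numpy as np

    common_prefix_len = 16
    seq_lens_np = np.array([20, 24], dtype=np.int32)
    num_actual_tokens = 6   # total new query tokens across both requests

    cu_prefix_query_lens = [0, num_actual_tokens]       # one prefix pass
    suffix_kv_lens = seq_lens_np - common_prefix_len    # -> [4, 8]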
@@ -302,7 +303,7 @@ class FlashAttentionMetadataBuilder(
             causal=True)

         attn_metadata = FlashAttentionMetadata(
-            num_actual_tokens=num_tokens,
+            num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
@@ -320,28 +321,13 @@ class FlashAttentionMetadataBuilder(
         )
         return attn_metadata

-    def build(
-        self, common_prefix_len: int,
-        common_attn_metadata: CommonAttentionMetadata
-    ) -> FlashAttentionMetadata:
-        num_reqs = common_attn_metadata.num_reqs
-        num_actual_tokens = common_attn_metadata.num_actual_tokens
-        max_query_len = common_attn_metadata.max_query_len
-        return self.build_slice(
-            req_slice=slice(0, num_reqs),
-            token_slice=slice(0, num_actual_tokens),
-            max_query_len=max_query_len,
-            common_prefix_len=common_prefix_len,
-            common_attn_metadata=common_attn_metadata,
-        )
-
     def can_run_in_cudagraph(
             self, common_attn_metadata: CommonAttentionMetadata) -> bool:
         # Full CUDA Graph always supported (FA2 support checked separately)
         return True

     def use_cascade_attention(self, *args, **kwargs) -> bool:
-        return False  # use_cascade_attention(*args, **kwargs)
+        return use_cascade_attention(*args, **kwargs)


 class FlashAttentionImpl(AttentionImpl):
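
Net effect of the last hunk: the build()/build_slice() split is gone, so the builder again exposes a single entry point covering the whole scheduled batch, and use_cascade_attention() delegates to the use_cascade_attention helper in scope instead of being stubbed to False. A caller-side sketch (builder and metadata plumbing are assumed names, not shown in this diff):

    # Metadata always spans [0, num_reqs) requests and
    # [0, num_actual_tokens) tokens after this commit.
    attn_metadata = builder.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
    )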