[Bugfix] Fix FA3 full cuda graph correctness (#19106)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Authored by Woosuk Kwon on 2025-06-03 23:10:15 -07:00; committed by GitHub
parent 41aa578428
commit b124e1085b
4 changed files with 32 additions and 10 deletions
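
In brief, as reconstructed from the diff below: under full CUDA graph capture, FlashAttention 3's AOT scheduler metadata must live in a preallocated, fixed-address buffer that is refreshed and zero-padded on every step; the dummy capture runs must use max_model_len sequence lengths; AOT scheduling becomes a hard requirement under full_cuda_graph instead of being silently disabled; and a correctness test comparing full-graph against piecewise outputs is added to CI.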

File: .buildkite/test-pipeline.yaml

@@ -320,6 +320,7 @@ steps:
     # these tests need to be separated, cannot combine
     - pytest -v -s compile/piecewise/test_simple.py
     - pytest -v -s compile/piecewise/test_toy_llama.py
+    - pytest -v -s compile/piecewise/test_full_cudagraph.py
 - label: PyTorch Fullgraph Test # 18min
   mirror_hardwares: [amdexperimental, amdproduction]

File: tests/compile/piecewise/test_full_cudagraph.py

@@ -7,6 +7,7 @@ import pytest

 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig
+from vllm.platforms import current_platform

 MODEL = "Qwen/Qwen2-1.5B-Instruct"
@@ -37,7 +38,7 @@ def full_cudagraph_llm():
             "VLLM_FLASH_ATTN_VERSION": "3"
     }):
         return LLM(model=MODEL,
-                   gpu_memory_utilization=0.2,
+                   gpu_memory_utilization=0.3,
                    compilation_config=CompilationConfig(full_cuda_graph=True))
@@ -48,7 +49,7 @@ def piecewise_llm():
             "VLLM_FLASH_ATTN_VERSION": "3"
     }):
         return LLM(model=MODEL,
-                   gpu_memory_utilization=0.5,
+                   gpu_memory_utilization=0.6,
                    compilation_config=CompilationConfig())
@@ -61,6 +62,8 @@ def generate_text(llm: LLM, batch_size: int, max_tokens: int):
     return llm.generate(prompts, sampling_params)


+@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
+                    reason="Only Hopper GPUs support FlashAttention 3")
 @pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
                                                         (16, 10), (25, 10),
                                                         (32, 10), (45, 10),
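
The hunks above only show the fixtures and memory-headroom bumps; the test body itself falls outside the diff. A minimal sketch of the body that would sit under the parametrize decorator shown above, relying on the generate_text helper and fixtures from this file (the function name and assertion are assumptions, not taken from the diff):

# Hypothetical sketch: exercise both engines and require identical decodes.
def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm,
                        piecewise_llm):
    full_out = generate_text(full_cudagraph_llm, batch_size, max_tokens)
    base_out = generate_text(piecewise_llm, batch_size, max_tokens)
    for full, base in zip(full_out, base_out):
        # Matching text between full-graph and piecewise runs is the
        # correctness signal this bugfix is meant to restore.
        assert full.outputs[0].text == base.outputs[0].text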

File: vllm/v1/attention/backends/flash_attn.py

@@ -307,13 +307,14 @@ class FlashAttentionMetadataBuilder:
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table

-        if get_flash_attn_version() == 3:
-            self.aot_schedule = not compilation_config.full_cuda_graph
-            if not self.aot_schedule:
-                logger.warning(
-                    "AOT Schedule is disabled when using full_cuda_graph")
-        else:
-            self.aot_schedule = False
+        self.aot_schedule = (get_flash_attn_version() == 3)
+        self.use_full_cuda_graph = compilation_config.full_cuda_graph
+        if self.use_full_cuda_graph and not self.aot_schedule:
+            raise ValueError("Full CUDA graph mode requires AOT scheduling, "
+                             "which requires FlashAttention 3.")
+        self.scheduler_metadata = torch.zeros(self.runner.max_num_reqs + 1,
+                                              dtype=torch.int32,
+                                              device=self.runner.device)

         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
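
Why the builder now preallocates scheduler_metadata: a captured CUDA graph replays its kernels against the exact tensor addresses seen at capture time, so per-step inputs have to be staged into fixed buffers rather than freshly allocated tensors. A self-contained illustration of that staging pattern in standalone PyTorch (not vLLM code; warm-up iterations omitted for brevity):

import torch

# Capture once against static buffers...
static_in = torch.zeros(4, device="cuda")
static_out = torch.zeros(4, device="cuda")
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(static_in * 2)

# ...then replay with new data staged in place at the captured addresses.
static_in.copy_(torch.arange(4.0, device="cuda"))
g.replay()
print(static_out)  # tensor([0., 2., 4., 6.], device='cuda:0')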
@@ -326,7 +327,7 @@ class FlashAttentionMetadataBuilder:
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata):
-        max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
+        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table = self.block_table
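
The int(...) cast is easy to miss: NumPy's max() returns a NumPy scalar (e.g. numpy.int32), and the cast pins max_seq_len to a plain Python int before it feeds the scheduler. A two-line demonstration of the type difference:

import numpy as np

seq_lens_np = np.array([3, 7], dtype=np.int32)
print(type(seq_lens_np.max()))       # <class 'numpy.int32'>
print(type(int(seq_lens_np.max())))  # <class 'int'>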
@@ -448,6 +449,18 @@ class FlashAttentionMetadataBuilder:
                 max_seq_len=max_seq_len,
                 causal=True)

+        if self.use_full_cuda_graph:
+            assert scheduler_metadata is not None
+            n = scheduler_metadata.shape[0]
+            self.scheduler_metadata[:n].copy_(scheduler_metadata,
+                                              non_blocking=True)
+            # NOTE(woosuk): We should zero out the rest of the scheduler
+            # metadata to guarantee correctness. Otherwise, some thread
+            # blocks may use invalid scheduler metadata and overwrite the
+            # output buffer.
+            self.scheduler_metadata[n:] = 0
+            scheduler_metadata = self.scheduler_metadata[:n]
+
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
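
The zero-fill matters because scheduler_metadata is reassigned to a view into the persistent buffer, which survives between steps: entries past this step's n would otherwise carry stale values from an earlier, larger batch into the captured kernel. A small sketch of the view semantics (plain PyTorch, illustrative only):

import torch

buf = torch.tensor([9, 9, 9, 9, 9], dtype=torch.int32)  # leftovers from a prior step
fresh = torch.tensor([1, 2], dtype=torch.int32)
n = fresh.shape[0]
buf[:n].copy_(fresh)
print(buf)   # tensor([1, 2, 9, 9, 9], ...) -- stale tail still visible to the kernel
buf[n:] = 0
print(buf)   # tensor([1, 2, 0, 0, 0], ...) -- safe to replay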

File: vllm/v1/worker/gpu_model_runner.py

@@ -1750,6 +1750,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             attn_metadata: Optional[dict[str, Any]] = None
         else:
             query_start_loc = self.query_start_loc[:num_reqs + 1]
+            # Make sure max_model_len is used at graph capture time.
+            self.seq_lens_np[:num_reqs] = self.max_model_len
+            self.seq_lens_np[num_reqs:] = 0
+            self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                           non_blocking=True)
             seq_lens = self.seq_lens[:num_reqs]

             common_attn_metadata = CommonAttentionMetadata(
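
Reading of this last hunk (inferred from the diff; the enclosing dummy-run method is not shown): during the dummy forward passes used for CUDA graph capture, every active request slot is given the worst-case length max_model_len, so the FA3 schedule baked into the captured graph is an upper bound that stays valid for any shorter real batch replayed later. The writes go to the CPU-side array (seq_lens_np, which evidently aliases seq_lens_cpu) and are then mirrored into the device tensor self.seq_lens with a non-blocking copy, matching the persistent-buffer pattern used for scheduler_metadata above.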