diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 8ab96b3b7ac3c..4ee6b499b5396 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -320,6 +320,7 @@ steps:
   # these tests need to be separated, cannot combine
   - pytest -v -s compile/piecewise/test_simple.py
   - pytest -v -s compile/piecewise/test_toy_llama.py
+  - pytest -v -s compile/piecewise/test_full_cudagraph.py
 
 - label: PyTorch Fullgraph Test # 18min
   mirror_hardwares: [amdexperimental, amdproduction]
diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py
index 3188ea40f9ee6..134bade486079 100644
--- a/tests/compile/piecewise/test_full_cudagraph.py
+++ b/tests/compile/piecewise/test_full_cudagraph.py
@@ -7,6 +7,7 @@
 import pytest
 
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig
+from vllm.platforms import current_platform
 
 MODEL = "Qwen/Qwen2-1.5B-Instruct"
@@ -37,7 +38,7 @@ def full_cudagraph_llm():
             "VLLM_FLASH_ATTN_VERSION": "3"
     }):
         return LLM(model=MODEL,
-                   gpu_memory_utilization=0.2,
+                   gpu_memory_utilization=0.3,
                    compilation_config=CompilationConfig(full_cuda_graph=True))
 
 
@@ -48,7 +49,7 @@ def piecewise_llm():
             "VLLM_FLASH_ATTN_VERSION": "3"
     }):
         return LLM(model=MODEL,
-                   gpu_memory_utilization=0.5,
+                   gpu_memory_utilization=0.6,
                    compilation_config=CompilationConfig())
 
 
@@ -61,6 +62,8 @@ def generate_text(llm: LLM, batch_size: int, max_tokens: int):
     return llm.generate(prompts, sampling_params)
 
 
+@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
+                    reason="Only Hopper GPUs support FlashAttention 3")
 @pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
                                                         (16, 10), (25, 10),
                                                         (32, 10), (45, 10),
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index a92c51883af1c..a9f748d026f4b 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -307,13 +307,14 @@ class FlashAttentionMetadataBuilder:
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table
 
-        if get_flash_attn_version() == 3:
-            self.aot_schedule = not compilation_config.full_cuda_graph
-            if not self.aot_schedule:
-                logger.warning(
-                    "AOT Schedule is disabled when using full_cuda_graph")
-        else:
-            self.aot_schedule = False
+        self.aot_schedule = (get_flash_attn_version() == 3)
+        self.use_full_cuda_graph = compilation_config.full_cuda_graph
+        if self.use_full_cuda_graph and not self.aot_schedule:
+            raise ValueError("Full CUDA graph mode requires AOT scheduling, "
+                             "which requires FlashAttention 3.")
+        self.scheduler_metadata = torch.zeros(self.runner.max_num_reqs + 1,
+                                              dtype=torch.int32,
+                                              device=self.runner.device)
 
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
@@ -326,7 +327,7 @@ class FlashAttentionMetadataBuilder:
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata):
-        max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
+        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table = self.block_table
@@ -448,6 +449,18 @@ class FlashAttentionMetadataBuilder:
                                           max_seq_len=max_seq_len,
                                           causal=True)
 
+        if self.use_full_cuda_graph:
+            assert scheduler_metadata is not None
+            n = scheduler_metadata.shape[0]
+            self.scheduler_metadata[:n].copy_(scheduler_metadata,
+                                              non_blocking=True)
+            # NOTE(woosuk): We should zero out the rest of the scheduler
+            # metadata to guarantee the correctness. Otherwise, some thread
+            # blocks may use the invalid scheduler metadata and overwrite the
+            # output buffer.
+            self.scheduler_metadata[n:] = 0
+            scheduler_metadata = self.scheduler_metadata[:n]
+
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 9ac33a1499610..4a67e37781bf6 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1750,6 +1750,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             attn_metadata: Optional[dict[str, Any]] = None
         else:
            query_start_loc = self.query_start_loc[:num_reqs + 1]
+            # Make sure max_model_len is used at the graph capture time.
+            self.seq_lens_np[:num_reqs] = self.max_model_len
+            self.seq_lens_np[num_reqs:] = 0
+            self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                           non_blocking=True)
             seq_lens = self.seq_lens[:num_reqs]
 
             common_attn_metadata = CommonAttentionMetadata(
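
For reference, a minimal usage sketch of the option this patch enables, mirroring the `full_cudagraph_llm` test fixture above. It assumes a Hopper (compute capability 9.0) GPU so FlashAttention 3 is available; the prompt, sampling settings, and memory fraction are illustrative only.

```python
# Minimal sketch: enable full CUDA graph capture with the FA3 AOT scheduler.
# Assumes a Hopper (SM 9.0) GPU; settings mirror the test fixture in this PR.
import os

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

# Force FlashAttention 3, as the test fixture does via temporary_environ.
os.environ["VLLM_FLASH_ATTN_VERSION"] = "3"

llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
          gpu_memory_utilization=0.3,
          compilation_config=CompilationConfig(full_cuda_graph=True))

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=10))
print(outputs[0].outputs[0].text)
```

With `full_cuda_graph=True`, the builder now requires AOT scheduling (FlashAttention 3) instead of silently disabling it, and it copies the AOT scheduler metadata into a persistent, pre-zeroed buffer so the captured graph always reads valid metadata.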