mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:54:56 +08:00
[Bugfix] Fix FA3 full cuda graph correctness (#19106)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
41aa578428
commit
b124e1085b
@ -320,6 +320,7 @@ steps:
|
||||
# these tests need to be separated, cannot combine
|
||||
- pytest -v -s compile/piecewise/test_simple.py
|
||||
- pytest -v -s compile/piecewise/test_toy_llama.py
|
||||
- pytest -v -s compile/piecewise/test_full_cudagraph.py
|
||||
|
||||
- label: PyTorch Fullgraph Test # 18min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
|
||||
@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import CompilationConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODEL = "Qwen/Qwen2-1.5B-Instruct"
|
||||
|
||||
@ -37,7 +38,7 @@ def full_cudagraph_llm():
|
||||
"VLLM_FLASH_ATTN_VERSION": "3"
|
||||
}):
|
||||
return LLM(model=MODEL,
|
||||
gpu_memory_utilization=0.2,
|
||||
gpu_memory_utilization=0.3,
|
||||
compilation_config=CompilationConfig(full_cuda_graph=True))
|
||||
|
||||
|
||||
@ -48,7 +49,7 @@ def piecewise_llm():
|
||||
"VLLM_FLASH_ATTN_VERSION": "3"
|
||||
}):
|
||||
return LLM(model=MODEL,
|
||||
gpu_memory_utilization=0.5,
|
||||
gpu_memory_utilization=0.6,
|
||||
compilation_config=CompilationConfig())
|
||||
|
||||
|
||||
@ -61,6 +62,8 @@ def generate_text(llm: LLM, batch_size: int, max_tokens: int):
|
||||
return llm.generate(prompts, sampling_params)
|
||||
|
||||
|
||||
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
|
||||
reason="Only Hopper GPUs support FlashAttention 3")
|
||||
@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
|
||||
(16, 10), (25, 10),
|
||||
(32, 10), (45, 10),
|
||||
|
||||
@ -307,13 +307,14 @@ class FlashAttentionMetadataBuilder:
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_table = block_table
|
||||
|
||||
if get_flash_attn_version() == 3:
|
||||
self.aot_schedule = not compilation_config.full_cuda_graph
|
||||
if not self.aot_schedule:
|
||||
logger.warning(
|
||||
"AOT Schedule is disabled when using full_cuda_graph")
|
||||
else:
|
||||
self.aot_schedule = False
|
||||
self.aot_schedule = (get_flash_attn_version() == 3)
|
||||
self.use_full_cuda_graph = compilation_config.full_cuda_graph
|
||||
if self.use_full_cuda_graph and not self.aot_schedule:
|
||||
raise ValueError("Full CUDA graph mode requires AOT scheduling, "
|
||||
"which requires FlashAttention 3.")
|
||||
self.scheduler_metadata = torch.zeros(self.runner.max_num_reqs + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
|
||||
# Sliding window size to be used with the AOT scheduler will be
|
||||
# populated on first build() call.
|
||||
@ -326,7 +327,7 @@ class FlashAttentionMetadataBuilder:
|
||||
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata):
|
||||
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
|
||||
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table = self.block_table
|
||||
@ -448,6 +449,18 @@ class FlashAttentionMetadataBuilder:
|
||||
max_seq_len=max_seq_len,
|
||||
causal=True)
|
||||
|
||||
if self.use_full_cuda_graph:
|
||||
assert scheduler_metadata is not None
|
||||
n = scheduler_metadata.shape[0]
|
||||
self.scheduler_metadata[:n].copy_(scheduler_metadata,
|
||||
non_blocking=True)
|
||||
# NOTE(woosuk): We should zero out the rest of the scheduler
|
||||
# metadata to guarantee the correctness. Otherwise, some thread
|
||||
# blocks may use the invalid scheduler metadata and overwrite the
|
||||
# output buffer.
|
||||
self.scheduler_metadata[n:] = 0
|
||||
scheduler_metadata = self.scheduler_metadata[:n]
|
||||
|
||||
attn_metadata = FlashAttentionMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
|
||||
@ -1750,6 +1750,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
attn_metadata: Optional[dict[str, Any]] = None
|
||||
else:
|
||||
query_start_loc = self.query_start_loc[:num_reqs + 1]
|
||||
# Make sure max_model_len is used at the graph capture time.
|
||||
self.seq_lens_np[:num_reqs] = self.max_model_len
|
||||
self.seq_lens_np[num_reqs:] = 0
|
||||
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
|
||||
non_blocking=True)
|
||||
seq_lens = self.seq_lens[:num_reqs]
|
||||
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user