Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-13 20:35:26 +08:00)
[Bugfix] Fix FA3 full cuda graph correctness (#19106)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
commit b124e1085b
parent 41aa578428
@@ -320,6 +320,7 @@ steps:
   # these tests need to be separated, cannot combine
   - pytest -v -s compile/piecewise/test_simple.py
   - pytest -v -s compile/piecewise/test_toy_llama.py
+  - pytest -v -s compile/piecewise/test_full_cudagraph.py

 - label: PyTorch Fullgraph Test # 18min
   mirror_hardwares: [amdexperimental, amdproduction]
@@ -7,6 +7,7 @@ import pytest

 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig
+from vllm.platforms import current_platform

 MODEL = "Qwen/Qwen2-1.5B-Instruct"

@@ -37,7 +38,7 @@ def full_cudagraph_llm():
             "VLLM_FLASH_ATTN_VERSION": "3"
     }):
         return LLM(model=MODEL,
-                   gpu_memory_utilization=0.2,
+                   gpu_memory_utilization=0.3,
                    compilation_config=CompilationConfig(full_cuda_graph=True))


@@ -48,7 +49,7 @@ def piecewise_llm():
             "VLLM_FLASH_ATTN_VERSION": "3"
     }):
         return LLM(model=MODEL,
-                   gpu_memory_utilization=0.5,
+                   gpu_memory_utilization=0.6,
                    compilation_config=CompilationConfig())


@@ -61,6 +62,8 @@ def generate_text(llm: LLM, batch_size: int, max_tokens: int):
     return llm.generate(prompts, sampling_params)


+@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
+                    reason="Only Hopper GPUs support FlashAttention 3")
 @pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
                                                          (16, 10), (25, 10),
                                                          (32, 10), (45, 10),
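For reference, a minimal standalone sketch (not part of this change) of the same Hopper-only gate written with plain PyTorch instead of vLLM's current_platform helper: torch.cuda.get_device_capability() returns a (major, minor) tuple, and (9, 0) is SM90/Hopper, the only architecture FlashAttention 3 targets here.

import pytest
import torch

# Skip unless a Hopper (SM90) GPU is present, mirroring the skipif added above.
requires_hopper = pytest.mark.skipif(
    not torch.cuda.is_available()
    or torch.cuda.get_device_capability() != (9, 0),
    reason="Only Hopper GPUs support FlashAttention 3")


@requires_hopper
def test_runs_only_on_hopper():
    assert torch.cuda.get_device_capability() == (9, 0)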
@@ -307,13 +307,14 @@ class FlashAttentionMetadataBuilder:
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table

-        if get_flash_attn_version() == 3:
-            self.aot_schedule = not compilation_config.full_cuda_graph
-            if not self.aot_schedule:
-                logger.warning(
-                    "AOT Schedule is disabled when using full_cuda_graph")
-        else:
-            self.aot_schedule = False
+        self.aot_schedule = (get_flash_attn_version() == 3)
+        self.use_full_cuda_graph = compilation_config.full_cuda_graph
+        if self.use_full_cuda_graph and not self.aot_schedule:
+            raise ValueError("Full CUDA graph mode requires AOT scheduling, "
+                             "which requires FlashAttention 3.")
+        self.scheduler_metadata = torch.zeros(self.runner.max_num_reqs + 1,
+                                              dtype=torch.int32,
+                                              device=self.runner.device)

         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
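For context, a minimal self-contained sketch (plain PyTorch, not vLLM code; buffer names and sizes are illustrative) of the pattern the new scheduler_metadata buffer relies on: any tensor a captured CUDA graph reads must live at a fixed address, so it is allocated once up front and only ever updated in place before each replay.

import torch

# Assumes a CUDA device is available.
static_in = torch.zeros(8, device="cuda")    # fixed-address input buffer
static_out = torch.zeros(8, device="cuda")   # fixed-address output buffer

# Warm up on a side stream before capture, as recommended for CUDA graphs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(static_in * 2)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(static_in * 2)   # the graph records these fixed addresses

# Before replay, mutate the captured buffer in place; rebinding the name to a
# new tensor would leave the graph reading the old memory.
static_in.copy_(torch.arange(8.0, device="cuda"))
g.replay()
torch.cuda.synchronize()
print(static_out)  # tensor([0., 2., 4., 6., 8., 10., 12., 14.], device='cuda:0')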
@@ -326,7 +327,7 @@ class FlashAttentionMetadataBuilder:
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata):
-        max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
+        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table = self.block_table
@@ -448,6 +449,18 @@ class FlashAttentionMetadataBuilder:
                 max_seq_len=max_seq_len,
                 causal=True)

+        if self.use_full_cuda_graph:
+            assert scheduler_metadata is not None
+            n = scheduler_metadata.shape[0]
+            self.scheduler_metadata[:n].copy_(scheduler_metadata,
+                                              non_blocking=True)
+            # NOTE(woosuk): We should zero out the rest of the scheduler
+            # metadata to guarantee the correctness. Otherwise, some thread
+            # blocks may use the invalid scheduler metadata and overwrite the
+            # output buffer.
+            self.scheduler_metadata[n:] = 0
+            scheduler_metadata = self.scheduler_metadata[:n]
+
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
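The zero-fill above is the correctness-critical part: the AOT scheduler emits a variable-length metadata tensor each step, while the captured graph always reads the fixed-size buffer, so the unused tail must be cleared or a later, smaller step could leave thread blocks consuming stale entries from an earlier one. A hedged CPU-only sketch of that update (names are illustrative, not vLLM's):

import torch

MAX_NUM_REQS = 8
persistent = torch.zeros(MAX_NUM_REQS + 1, dtype=torch.int32)


def update(fresh: torch.Tensor) -> torch.Tensor:
    # Overwrite the live prefix, then zero the tail so nothing stale survives.
    n = fresh.shape[0]
    persistent[:n].copy_(fresh)
    persistent[n:] = 0
    return persistent[:n]


update(torch.tensor([3, 1, 4, 1, 5], dtype=torch.int32))  # step with 5 entries
update(torch.tensor([2, 7], dtype=torch.int32))           # smaller follow-up step
assert persistent[2:].eq(0).all()  # no leftovers from the 5-entry step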
@@ -1750,6 +1750,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             attn_metadata: Optional[dict[str, Any]] = None
         else:
             query_start_loc = self.query_start_loc[:num_reqs + 1]
+            # Make sure max_model_len is used at the graph capture time.
+            self.seq_lens_np[:num_reqs] = self.max_model_len
+            self.seq_lens_np[num_reqs:] = 0
+            self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                           non_blocking=True)
             seq_lens = self.seq_lens[:num_reqs]

             common_attn_metadata = CommonAttentionMetadata(
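A short CPU-only sketch of the capture-time padding added above (variable names mirror the diff, but the surrounding runner is omitted and the device buffer is stubbed): filling the dummy run's sequence lengths with max_model_len makes the FA3 schedule baked into the CUDA graph cover the worst case, so real batches with shorter sequences always fit at replay.

import numpy as np
import torch

max_model_len, max_num_reqs, num_reqs = 4096, 8, 3

seq_lens_np = np.zeros(max_num_reqs, dtype=np.int32)
seq_lens_cpu = torch.from_numpy(seq_lens_np)             # shares storage with the numpy view
seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32)  # stands in for the device buffer

# Make sure max_model_len is used at graph capture time; padded slots stay 0.
seq_lens_np[:num_reqs] = max_model_len
seq_lens_np[num_reqs:] = 0
seq_lens[:num_reqs].copy_(seq_lens_cpu[:num_reqs])
print(seq_lens)  # tensor([4096, 4096, 4096, 0, 0, 0, 0, 0], dtype=torch.int32)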