[Kernel] Flashinfer for prefill & decode, with Cudagraph support for decode (#4628)

Co-authored-by: LiuXiaoxuanPKU <llilyliupku@gmail.com>, bong-furiosa <bongwon.jang@furiosa.ai>
Lily Liu 2024-06-28 15:28:49 -07:00 committed by GitHub
parent 6a62cb82cc
commit 7041de4384
7 changed files with 313 additions and 117 deletions

View File

@@ -211,3 +211,6 @@ steps:
   - pytest -v -s distributed/test_custom_all_reduce.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
+  - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl
+  - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
+  - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
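The new CI steps above install the FlashInfer wheel and re-run the distributed correctness test with VLLM_ATTENTION_BACKEND=FLASHINFER. For local experimentation the backend can be selected the same way; a minimal sketch, assuming a CUDA GPU and an installed flashinfer wheel (the model and prompt are illustrative, not part of this commit):

import os

# Select the FlashInfer attention backend before constructing the engine.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # small model also used by the CI steps above
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)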

View File

@@ -19,4 +19,4 @@ sentence-transformers # required for embedding
 aiohttp
 # quantization
 bitsandbytes==0.42.0

View File

@@ -2,7 +2,6 @@
 Run `pytest tests/basic_correctness/test_basic_correctness.py`.
 """
-import os
 import weakref
 import pytest
@@ -13,7 +12,6 @@ MODELS = [
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
 ]
-VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
 def test_vllm_gc_ed():
@@ -39,10 +37,6 @@ def test_models(
     max_tokens: int,
     enforce_eager: bool,
 ) -> None:
-    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
-    if backend_by_env_var == "FLASHINFER" and enforce_eager is False:
-        pytest.skip("Skipping non-eager test for FlashInferBackend.")
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

View File

@@ -21,7 +21,6 @@ MODELS = [
     os.environ["TEST_DIST_MODEL"],
 ]
 DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
-VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -39,16 +38,12 @@ def test_models(
 ) -> None:
     distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)
-    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
-    enforce_eager = backend_by_env_var == "FLASHINFER"
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
     with vllm_runner(model,
                      dtype=dtype,
                      tensor_parallel_size=2,
-                     enforce_eager=enforce_eager,
                      distributed_executor_backend=distributed_executor_backend
                      ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

View File

@@ -1,10 +1,16 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Set, Tuple, Type
-import flashinfer
+try:
+    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
+    from vllm_flash_attn import flash_attn_varlen_func
+except ImportError:
+    flash_attn_varlen_func = None
+    BatchDecodeWithPagedKVCacheWrapper = None
+    BatchPrefillWithPagedKVCacheWrapper = None
 import torch
-from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-from vllm_flash_attn import flash_attn_varlen_func
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -60,19 +66,16 @@ class FlashInferMetadata(AttentionMetadata):
     # requests only.
     max_prefill_seq_len: int
-    use_cuda_graph: bool = False
+    use_cuda_graph: bool = True
+    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
-    # Metadata for the prefill stage since we still
-    # use flash attention for prefill.
+    # Metadata for the prefill stage
     seq_start_loc: Optional[torch.Tensor] = None
+    query_start_loc: Optional[torch.Tensor] = None
     block_tables: Optional[torch.Tensor] = None
-    # Metadata for the decode stage
-    # Workspace buffer required by the kernel, the buffer should not
-    # be allocated/deacollated by the FalshInfermetadata object.
-    workspace_buffer: Optional[torch.Tensor] = None
     # An example for paged_kv_indices, paged_kv_indptr:
     # request 1, page indices [0, 5, 8]
     # request 2, page indices [1, 6, 7]
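To make the CSR-style layout described in this comment concrete, here is a small stand-alone sketch (plain Python, not part of the commit) that builds paged_kv_indptr, paged_kv_indices, and paged_kv_last_page_len for the two requests above; the page size of 16 and the sequence lengths of 40 and 35 tokens are assumed for illustration:

# Hypothetical per-request page tables and sequence lengths.
page_size = 16
page_tables = [[0, 5, 8], [1, 6, 7]]  # request 1 and request 2, as in the comment
seq_lens = [40, 35]

paged_kv_indices = []        # all page indices, requests concatenated
paged_kv_indptr = [0]        # request i owns indices[indptr[i]:indptr[i+1]]
paged_kv_last_page_len = []  # valid tokens in each request's last page

for pages, seq_len in zip(page_tables, seq_lens):
    paged_kv_indices.extend(pages)
    paged_kv_indptr.append(paged_kv_indptr[-1] + len(pages))
    paged_kv_last_page_len.append(seq_len % page_size or page_size)

print(paged_kv_indices)        # [0, 5, 8, 1, 6, 7]
print(paged_kv_indptr)         # [0, 3, 6]
print(paged_kv_last_page_len)  # [8, 3]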
@@ -98,6 +101,7 @@ class FlashInferMetadata(AttentionMetadata):
     page_size: Optional[int] = None
     # The data type of the paged kv cache
     data_type: torch.dtype = None
+    device: torch.device = torch.device("cuda")
     def __post_init__(self):
         # Refer to
@@ -109,13 +113,35 @@ class FlashInferMetadata(AttentionMetadata):
                 f"Only {supported_head_sizes} are supported for head_dim,",
                 f"received {self.head_dim}.")
-        # When using flashinfer, we are also creating the FlashInferMetadata,
-        # which will also call post_init by default, here we want to skip the
-        # post_init if it's the prefill phase.
-        if self.num_prefills == 0:
-            assert self.num_decode_tokens > 0
-            self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
-                self.workspace_buffer, "NHD")
+    def begin_forward(self):
+        if self.num_prefill_tokens > 0:
+            if self.paged_kv_indices is None:
+                return
+
+            assert self.prefill_wrapper is not None
+            assert self.paged_kv_indices is not None
+            assert self.paged_kv_indptr is not None
+            assert self.paged_kv_last_page_len is not None
+            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
+            self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
+            self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
+                self.device)
+            self.prefill_wrapper.begin_forward(
+                self.query_start_loc, self.paged_kv_indptr,
+                self.paged_kv_indices, self.paged_kv_last_page_len,
+                self.num_qo_heads, self.num_kv_heads, self.head_dim,
+                self.page_size)
+        else:
+            if not self.use_cuda_graph:
+                assert self.paged_kv_indices is not None
+                assert self.paged_kv_indptr is not None
+                assert self.paged_kv_last_page_len is not None
+                self.paged_kv_indices = self.paged_kv_indices.to(self.device)
+                self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
+                self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
+                    self.device)
+
+            assert self.decode_wrapper is not None
             self.decode_wrapper.begin_forward(
                 self.paged_kv_indptr,
                 self.paged_kv_indices,
@@ -133,8 +159,9 @@ class FlashInferMetadata(AttentionMetadata):
     ) -> Dict[str, Any]:
         if skip_fields is None:
             skip_fields = set()
-        # We need to skip the decode_wrapper field since it cannot be
+        # We need to skip the prefill/decode_wrapper field since it cannot be
         # broadcasted with nccl when TP is enabled.
+        skip_fields.add('prefill_wrapper')
         skip_fields.add('decode_wrapper')
         return super().asdict_zerocopy(skip_fields)
@@ -168,6 +195,7 @@ class FlashInferImpl(AttentionImpl):
         alibi_slopes: Optional[List[float]],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -217,10 +245,14 @@ class FlashInferImpl(AttentionImpl):
             self.kv_cache_dtype,
         )
+        query = query.contiguous(
+        )  # Flashinfer requires query to be contiguous
         if prefill_meta := attn_metadata.prefill_metadata:
-            # Prompt run.
-            assert prefill_meta.block_tables is not None
-            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
+            # We will use flash attention for prefill
+            # when kv_cache is not provided.
+            # This happens when vllm runs the profiling to
+            # determine the number of blocks.
+            if kv_cache is None:
                 output = flash_attn_varlen_func(
                     q=query,
                     k=key,
@@ -235,13 +267,14 @@ class FlashInferImpl(AttentionImpl):
                     alibi_slopes=self.alibi_slopes,
                 )
             else:
-                raise NotImplementedError(
-                    "Prefix caching is not supported with flashinfer yet.")
+                assert prefill_meta is not None
+                assert prefill_meta.prefill_wrapper is not None
+                output = prefill_meta.prefill_wrapper.forward(query,
+                                                              kv_cache,
+                                                              causal=True)
         else:
             assert attn_metadata.decode_metadata is not None
             assert attn_metadata.decode_metadata.decode_wrapper is not None
-            query = query.contiguous(
-            )  # Flashinfer requires query to be contiguous
             output = attn_metadata.decode_metadata.decode_wrapper.forward(
                 query,
                 kv_cache,

View File

@@ -77,8 +77,9 @@ def get_attn_backend(
         return IpexAttnBackend
     elif backend == _Backend.FLASHINFER:
         logger.info("Using Flashinfer backend.")
-        logger.warning("Eager mode is required for the Flashinfer backend. "
-                       "Please make sure --enforce-eager is set.")
+        logger.warning(("Flashinfer will be stuck on llama-2-7b, "
+                        "please avoid using Flashinfer as the "
+                        "backend when running on llama-2-7b."))
         from vllm.attention.backends.flashinfer import FlashInferBackend
         return FlashInferBackend
     elif backend == _Backend.PALLAS:

View File

@@ -10,6 +10,17 @@ import numpy as np
 import torch
 import torch.nn as nn
+try:
+    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
+except ImportError:
+    BatchDecodeWithPagedKVCacheWrapper = None
+    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
+    BatchPrefillWithPagedKVCacheWrapper = None
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
@@ -198,11 +209,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         # Lazy initialization
         self.model: nn.Module  # Set after load_model
-        # Set if the backend is flashinfer.
-        self.flashinfer_workspace_buffer: torch.Tensor
         # Set after load_model.
         self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
+        self.flashinfer_decode_workspace_buffer = None
+        self.flashinfer_decode_wrapper = None
+        self.flashinfer_prefill_workspace_buffer = None
+        self.flashinfer_prefill_wrapper = None
     def load_model(self) -> None:
         with CudaMemoryProfiler() as m:
             self.model = get_model(
@@ -450,15 +464,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                     if curr_sliding_window_blocks is not None:
                         block_table = block_table[
                             -curr_sliding_window_blocks:]
-                    if self.attn_backend.get_name() == "flashinfer":
-                        paged_kv_indices.extend(block_table)
-                        paged_kv_indptr.append(paged_kv_indptr[-1] +
-                                               len(block_table))
-                        last_page_len = seq_data.get_len(
-                        ) % self.block_size
-                        if last_page_len == 0:
-                            last_page_len = self.block_size
-                        paged_kv_last_page_len.append(last_page_len)
                 else:
                     # Only happens when memory profiling runs.
                     block_table = []
@@ -505,7 +510,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             for k, v in mm_kwargs.items():
                 multi_modal_kwargs_list[k].append(v)
-            if _is_block_tables_empty(seq_group_metadata.block_tables):
+            is_profile_run = _is_block_tables_empty(
+                seq_group_metadata.block_tables)
+            if is_profile_run:
                 # During memory profiling, the block tables are not
                 # initialized yet. In this case, we just use a dummy
                 # slot mapping.
@@ -544,6 +551,27 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                     slot = block_number * self.block_size + block_offset
                     slot_mapping.append(slot)
+            # Prepare input tensors for flashinfer
+            if self.attn_backend.get_name() == "flashinfer":
+                seq_len = seq_data.get_len()
+                # Get the number of valid blocks based on sequence length.
+                # If seq_len = 16, block_size = 16,
+                # block_table_bound is 1 with 1 valid block.
+                # If seq_len = 15, block_size = 16,
+                # block_table_bound is 0 + 1 with 1 valid block.
+                block_table_bound = seq_len // self.block_size + 1 \
+                    if seq_len % self.block_size != 0 \
+                    else seq_len // self.block_size
+                paged_kv_indices.extend(block_table[:block_table_bound])
+                paged_kv_indptr.append(paged_kv_indptr[-1] +
+                                       block_table_bound)
+                last_page_len = seq_len % self.block_size
+                if last_page_len == 0:
+                    last_page_len = self.block_size
+                paged_kv_last_page_len.append(last_page_len)
         batch_size = len(input_tokens)
         max_query_len = max(query_lens)
         max_prefill_seq_len = max(prefill_seq_lens, default=0)
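The block_table_bound expression in the hunk above is a ceiling division of seq_len by block_size, so only pages that actually hold tokens are handed to FlashInfer. A small sketch of the equivalence, with assumed example values:

def block_table_bound(seq_len: int, block_size: int) -> int:
    # Number of KV-cache pages that hold tokens for this sequence,
    # written the same way as in the hunk above.
    return seq_len // block_size + 1 if seq_len % block_size != 0 \
        else seq_len // block_size

def block_table_bound_ceil(seq_len: int, block_size: int) -> int:
    # The same value via ceiling division.
    return -(-seq_len // block_size)

assert block_table_bound(16, 16) == block_table_bound_ceil(16, 16) == 1
assert block_table_bound(15, 16) == block_table_bound_ceil(15, 16) == 1
assert block_table_bound(17, 16) == block_table_bound_ceil(17, 16) == 2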
@@ -566,6 +594,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 seq_lens.append(1)
                 block_tables.append([])
                 lora_index_mapping.append(0)
+                if self.attn_backend.get_name() == "flashinfer":
+                    last_paged_kv_indptr = paged_kv_indptr[-1]
+                    paged_kv_indptr.append(last_paged_kv_indptr)
+                    paged_kv_last_page_len.append(0)
             batch_size = graph_batch_size
             num_decode_tokens = batch_size
@@ -589,9 +623,19 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         )
         assert max_query_len > 0, ("query_lens: {}".format(query_lens))
+        context_lens_tensor = torch.tensor(context_lens,
+                                           dtype=torch.int,
+                                           device=self.device)
         seq_lens_tensor = torch.tensor(seq_lens,
                                        dtype=torch.int,
                                        device=self.device)
+        query_lens_tensor = torch.tensor(query_lens,
+                                         dtype=torch.long,
+                                         device=self.device)
+        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
+                                      dtype=torch.int32,
+                                      device=self.device)
         seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                     dtype=torch.int32,
                                     device=self.device)
@@ -600,6 +644,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                      dim=0,
                      dtype=seq_start_loc.dtype,
                      out=seq_start_loc[1:])
+        torch.cumsum(query_lens_tensor,
+                     dim=0,
+                     dtype=query_start_loc.dtype,
+                     out=query_start_loc[1:])
         input_tokens_tensor = torch.tensor(input_tokens,
                                            dtype=torch.long,
@@ -612,30 +660,30 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                                              device=self.device)
         if self.attn_backend.get_name() == "flashinfer":
-            if not hasattr(self, "flashinfer_workspace_buffer"):
-                # Allocate 16MB workspace buffer
-                # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
-                self.flashinfer_workspace_buffer = torch.empty(
-                    16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
-            paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
-                                                  dtype=torch.int,
-                                                  device=self.device)
-            paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
-                                                   dtype=torch.int,
-                                                   device=self.device)
-            paged_kv_last_page_len_tensor = torch.tensor(
-                paged_kv_last_page_len, dtype=torch.int, device=self.device)
+            if len(paged_kv_indptr) > 0:
+                paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
+                                                       device='cpu',
+                                                       dtype=torch.int)
+                paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
+                                                      device='cpu',
+                                                      dtype=torch.int)
+                paged_kv_last_page_len_tensor = torch.tensor(
+                    paged_kv_last_page_len, device='cpu', dtype=torch.int)
+            else:
+                paged_kv_indices_tensor = None
+                paged_kv_indptr_tensor = None
+                paged_kv_last_page_len_tensor = None
             kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
                                                       self.model_config.dtype)
             attn_metadata = self.attn_backend.make_metadata(
                 num_prefills=num_prefills,
                 slot_mapping=slot_mapping_tensor,
                 num_prefill_tokens=num_prefill_tokens,
                 num_decode_tokens=num_decode_tokens,
-                use_cuda_graph=False,
                 max_prefill_seq_len=max_prefill_seq_len,
                 block_tables=block_tables,
-                workspace_buffer=self.flashinfer_workspace_buffer,
                 paged_kv_indptr=paged_kv_indptr_tensor,
                 paged_kv_indices=paged_kv_indices_tensor,
                 paged_kv_last_page_len=paged_kv_last_page_len_tensor,
@@ -644,25 +692,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 num_kv_heads=self.model_config.get_num_kv_heads(
                     self.parallel_config),
                 head_dim=self.model_config.get_head_size(),
-                page_size=16,
+                page_size=self.block_size,
                 seq_start_loc=seq_start_loc,
-                data_type=kv_cache_dtype)
+                query_start_loc=query_start_loc,
+                device=self.device,
+                data_type=kv_cache_dtype,
+                use_cuda_graph=use_captured_graph)
         else:
-            context_lens_tensor = torch.tensor(context_lens,
-                                               dtype=torch.int,
-                                               device=self.device)
-            query_lens_tensor = torch.tensor(query_lens,
-                                             dtype=torch.long,
-                                             device=self.device)
-            query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
-                                          dtype=torch.int32,
-                                          device=self.device)
-            torch.cumsum(query_lens_tensor,
-                         dim=0,
-                         dtype=query_start_loc.dtype,
-                         out=query_start_loc[1:])
             attn_metadata = self.attn_backend.make_metadata(
                 num_prefills=num_prefills,
                 slot_mapping=slot_mapping_tensor,
@@ -854,27 +891,97 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]
+        if self.attn_backend.get_name() == "flashinfer":
+            # For flashinfer, different batch sizes will share the
+            # same workspace buffer.
+            decode_workspace_buffer = \
+                torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                            dtype=torch.uint8,
+                            device=self.device)
+            indices_buffer = torch.empty(max_batch_size *
+                                         self.cache_config.num_gpu_blocks,
+                                         dtype=torch.int32,
+                                         device=self.device)
+            indptr_buffer = torch.empty(max_batch_size + 1,
+                                        dtype=torch.int32,
+                                        device=self.device)
+            last_page_len_buffer = torch.empty(max_batch_size,
+                                               dtype=torch.int32,
+                                               device=self.device)
         with graph_capture() as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for batch_size in reversed(batch_size_capture_list):
-                # Create dummy attn_metadata.
-                attn_metadata = self.attn_backend.make_metadata(
-                    num_prefills=0,
-                    num_prefill_tokens=0,
-                    num_decode_tokens=batch_size,
-                    slot_mapping=slot_mapping[:batch_size],
-                    seq_lens=None,
-                    seq_lens_tensor=seq_lens[:batch_size],
-                    max_query_len=None,
-                    max_prefill_seq_len=0,
-                    max_decode_seq_len=self.max_seq_len_to_capture,
-                    query_start_loc=None,
-                    seq_start_loc=None,
-                    context_lens_tensor=None,
-                    block_tables=block_tables[:batch_size],
-                    use_cuda_graph=True,
-                )
+                if self.attn_backend.get_name() == "flashinfer":
+                    indptr_buffer = indptr_buffer[:batch_size + 1]
+                    last_page_len_buffer = last_page_len_buffer[:batch_size]
+
+                    num_qo_heads = self.model_config.get_num_attention_heads(
+                        self.parallel_config)
+                    num_kv_heads = self.model_config.get_num_kv_heads(
+                        self.parallel_config)
+                    if num_qo_heads // num_kv_heads >= 4:
+                        use_tensor_cores = True
+                    else:
+                        use_tensor_cores = False
+                    decode_wrapper = \
+                        CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
+                        decode_workspace_buffer, indptr_buffer, indices_buffer,
+                        last_page_len_buffer, "NHD", use_tensor_cores)
+                    kv_cache_dtype = get_kv_cache_torch_dtype(
+                        self.kv_cache_dtype, self.model_config.dtype)
+
+                    paged_kv_indptr_tensor_host = torch.arange(
+                        0, batch_size + 1, dtype=torch.int32)
+                    paged_kv_indices_tensor_host = torch.arange(
+                        0, batch_size, dtype=torch.int32)
+                    paged_kv_last_page_len_tensor_host = torch.full(
+                        (batch_size, ), self.block_size, dtype=torch.int32)
+                    query_start_loc_host = torch.arange(0,
+                                                        batch_size + 1,
+                                                        dtype=torch.int32)
+
+                    attn_metadata = self.attn_backend.make_metadata(
+                        num_prefills=0,
+                        slot_mapping=slot_mapping[:batch_size],
+                        num_prefill_tokens=0,
+                        num_decode_tokens=batch_size,
+                        max_prefill_seq_len=0,
+                        block_tables=block_tables,
+                        paged_kv_indptr=paged_kv_indptr_tensor_host,
+                        paged_kv_indices=paged_kv_indices_tensor_host,
+                        paged_kv_last_page_len=
+                        paged_kv_last_page_len_tensor_host,
+                        num_qo_heads=num_qo_heads,
+                        num_kv_heads=num_kv_heads,
+                        head_dim=self.model_config.get_head_size(),
+                        page_size=self.block_size,
+                        seq_start_loc=None,
+                        query_start_loc=query_start_loc_host,
+                        device=self.device,
+                        data_type=kv_cache_dtype,
+                        use_cuda_graph=True,
+                        decode_wrapper=decode_wrapper,
+                        prefill_wrapper=None)
+                    attn_metadata.begin_forward()
+                else:
+                    attn_metadata = self.attn_backend.make_metadata(
+                        num_prefills=0,
+                        num_prefill_tokens=0,
+                        num_decode_tokens=batch_size,
+                        slot_mapping=slot_mapping[:batch_size],
+                        seq_lens=None,
+                        seq_lens_tensor=seq_lens[:batch_size],
+                        max_query_len=None,
+                        max_prefill_seq_len=0,
+                        max_decode_seq_len=self.max_seq_len_to_capture,
+                        query_start_loc=None,
+                        seq_start_loc=None,
+                        context_lens_tensor=None,
+                        block_tables=block_tables[:batch_size],
+                        use_cuda_graph=True,
+                    )
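During capture, every request in the dummy decode batch is given exactly one KV-cache page, and tensor cores are enabled when the query-head to KV-head ratio is at least 4. A small sketch of what the host-side tensors above evaluate to, using an assumed batch size of 3, block_size of 16, and a 32/8 head configuration:

import torch

batch_size, block_size = 3, 16

paged_kv_indptr = torch.arange(0, batch_size + 1, dtype=torch.int32)
# tensor([0, 1, 2, 3]): request i owns pages indptr[i]:indptr[i+1]
paged_kv_indices = torch.arange(0, batch_size, dtype=torch.int32)
# tensor([0, 1, 2]): one dummy page per request
paged_kv_last_page_len = torch.full((batch_size, ), block_size,
                                    dtype=torch.int32)
# tensor([16, 16, 16]): each dummy page is treated as full

# GQA heuristic from the capture path above (head counts are illustrative).
num_qo_heads, num_kv_heads = 32, 8
use_tensor_cores = num_qo_heads // num_kv_heads >= 4
print(use_tensor_cores)  # True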
                 if self.lora_config:
                     lora_mapping = LoRAMapping(
@@ -883,8 +990,20 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                     )
                     self.set_active_loras(set(), lora_mapping)
-                graph_runner = CUDAGraphRunner(self.model)
-                hidden_states = graph_runner.capture(
+                graph_runner = CUDAGraphRunner(self.model,
+                                               self.attn_backend.get_name())
+                if self.attn_backend.get_name() == "flashinfer":
+                    graph_runner.flashinfer_indptr_buffer = indptr_buffer
+                    graph_runner.flashinfer_indices_buffer = indices_buffer
+                    graph_runner.flashinfer_last_page_len_buffer = \
+                        last_page_len_buffer
+                    graph_runner.flashinfer_decode_workspace_buffer = \
+                        decode_workspace_buffer
+                    graph_runner.flashinfer_decode_wrapper = \
+                        decode_wrapper
+                graph_runner.capture(
                     input_tokens[:batch_size],
                     input_positions[:batch_size],
                     hidden_states[:batch_size]
@@ -918,11 +1037,12 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         self,
         tensor_dict: Dict[str, Any],
     ) -> ModelInputForGPUWithSamplingMetadata:
-        return (
-            ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
-                tensor_dict,
-                attn_backend=self.attn_backend,
-            ))
+        model_input = \
+            ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
+                tensor_dict,
+                attn_backend=self.attn_backend,
+            )
+        return model_input
     def prepare_model_input(
         self,
@@ -970,6 +1090,36 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
             self.set_active_loras(model_input.lora_requests,
                                   model_input.lora_mapping)
+        if self.attn_backend.get_name() == "flashinfer":
+            assert model_input.attn_metadata is not None
+            assert model_input.input_tokens is not None
+            if self.flashinfer_decode_workspace_buffer is None:
+                self.flashinfer_decode_workspace_buffer = torch.empty(
+                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                    dtype=torch.uint8,
+                    device=self.device)
+                self.flashinfer_decode_wrapper = \
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_decode_workspace_buffer, "NHD")
+                self.flashinfer_prefill_workspace_buffer = torch.empty(
+                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                    dtype=torch.uint8,
+                    device=self.device)
+                self.flashinfer_prefill_wrapper = \
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.flashinfer_prefill_workspace_buffer, "NHD")
+
+            model_input.attn_metadata.prefill_wrapper = \
+                self.flashinfer_prefill_wrapper
+            if model_input.attn_metadata.use_cuda_graph:
+                batch_size = model_input.input_tokens.shape[0]
+                model_input.attn_metadata.decode_wrapper = self.graph_runners[
+                    batch_size].flashinfer_decode_wrapper
+            else:
+                model_input.attn_metadata.decode_wrapper = \
+                    self.flashinfer_decode_wrapper
+            model_input.attn_metadata.begin_forward()
+
         # Currently cuda graph is only supported by the decode phase.
         assert model_input.attn_metadata is not None
         prefill_meta = model_input.attn_metadata.prefill_metadata
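The wrapper wiring added above boils down to: one shared decode wrapper for eager runs, and the pre-built wrapper of the matching captured batch size for CUDA-graph runs, since a graph must replay with the exact buffers it was captured with. A toy sketch of that lookup with hypothetical stand-in objects, not vLLM code:

class _GraphRunner:
    def __init__(self, wrapper):
        self.flashinfer_decode_wrapper = wrapper

graph_runners = {1: _GraphRunner("wrapper_bs1"), 8: _GraphRunner("wrapper_bs8")}
shared_decode_wrapper = "eager_wrapper"

def pick_decode_wrapper(batch_size: int, use_cuda_graph: bool):
    # CUDA-graph decode must reuse the wrapper whose buffers were captured.
    if use_cuda_graph:
        return graph_runners[batch_size].flashinfer_decode_wrapper
    return shared_decode_wrapper

assert pick_decode_wrapper(8, True) == "wrapper_bs8"
assert pick_decode_wrapper(1, False) == "eager_wrapper"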
@@ -1020,13 +1170,22 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
 class CUDAGraphRunner:
-    def __init__(self, model: nn.Module):
+    def __init__(self, model: nn.Module, backend_name: str):
         self.model = model
+        self.backend_name = backend_name
         self.input_buffers: Dict[str, torch.Tensor] = {}
         self.output_buffers: Dict[str, torch.Tensor] = {}
         self._graph: Optional[torch.cuda.CUDAGraph] = None
+        self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None
+        self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None
+        self.flashinfer_indices_buffer: Optional[torch.Tensor] = None
+        self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None
+        self.flashinfer_decode_wrapper: Optional[
+            CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None
     @property
     def graph(self):
         assert self._graph is not None
@@ -1079,14 +1238,23 @@ class CUDAGraphRunner:
         torch.cuda.synchronize()
         # Save the input and output buffers.
-        self.input_buffers = {
-            "input_ids": input_ids,
-            "positions": positions,
-            "kv_caches": kv_caches,
-            "slot_mapping": attn_metadata.slot_mapping,
-            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
-            "block_tables": attn_metadata.decode_metadata.block_tables,
-        }
+        if self.backend_name == "flashinfer":
+            self.input_buffers = {
+                "input_ids": input_ids,
+                "positions": positions,
+                "kv_caches": kv_caches,
+                "slot_mapping": attn_metadata.slot_mapping,
+            }
+        else:
+            self.input_buffers = {
+                "input_ids": input_ids,
+                "positions": positions,
+                "kv_caches": kv_caches,
+                "slot_mapping": attn_metadata.slot_mapping,
+                "seq_lens_tensor":
+                attn_metadata.decode_metadata.seq_lens_tensor,
+                "block_tables": attn_metadata.decode_metadata.block_tables,
+            }
         self.output_buffers = {"hidden_states": hidden_states}
         return hidden_states
@@ -1106,10 +1274,12 @@ class CUDAGraphRunner:
         self.input_buffers["positions"].copy_(positions, non_blocking=True)
         self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                  non_blocking=True)
-        self.input_buffers["seq_lens_tensor"].copy_(
-            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
-        self.input_buffers["block_tables"].copy_(
-            attn_metadata.decode_metadata.block_tables, non_blocking=True)
+        if self.backend_name != "flashinfer":
+            self.input_buffers["seq_lens_tensor"].copy_(
+                attn_metadata.decode_metadata.seq_lens_tensor,
+                non_blocking=True)
+            self.input_buffers["block_tables"].copy_(
+                attn_metadata.decode_metadata.block_tables, non_blocking=True)
         # Run the graph.
         self.graph.replay()