diff --git a/csrc/cache.h b/csrc/cache.h
index 718a5f6cfd7f..4c142ce17f1b 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -24,6 +24,14 @@ void reshape_and_cache(
   const std::string& kv_cache_dtype,
   const float kv_scale);

+void reshape_and_cache_flash(
+  torch::Tensor& key,
+  torch::Tensor& value,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& slot_mapping,
+  const std::string& kv_cache_dtype);
+
 // Just for unittest
 void convert_fp8(
   torch::Tensor& src_cache,
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 24aaa2ff3e26..42f884c76c62 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -215,6 +215,41 @@ __global__ void reshape_and_cache_kernel(
   }
 }

+template <typename scalar_t>
+__global__ void reshape_and_cache_flash_kernel(
+  const scalar_t* __restrict__ key,          // [num_tokens, num_heads, head_size]
+  const scalar_t* __restrict__ value,        // [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ k_cache,            // [num_blocks, block_size, num_heads, head_size]
+  scalar_t* __restrict__ v_cache,            // [num_blocks, block_size, num_heads, head_size]
+  const int64_t* __restrict__ slot_mapping,  // [num_tokens]
+  const int block_stride,
+  const int key_stride,
+  const int value_stride,
+  const int num_heads,
+  const int head_size,
+  const int block_size) {
+  const int64_t token_idx = blockIdx.x;
+  const int64_t slot_idx = slot_mapping[token_idx];
+  // NOTE: slot_idx can be -1 if the token is padded
+  if (slot_idx < 0) {
+    return;
+  }
+  const int64_t block_idx = slot_idx / block_size;
+  const int64_t block_offset = slot_idx % block_size;
+  const int n = num_heads * head_size;
+  for (int i = threadIdx.x; i < n; i += blockDim.x) {
+    const int64_t src_key_idx = token_idx * key_stride + i;
+    const int64_t src_value_idx = token_idx * value_stride + i;
+    const int head_idx = i / head_size;
+    const int head_offset = i % head_size;
+    const int64_t tgt_value_idx = block_idx * block_stride
+                                  + block_offset * num_heads * head_size
+                                  + head_idx * head_size
+                                  + head_offset;
+    k_cache[tgt_value_idx] = key[src_key_idx];
+    v_cache[tgt_value_idx] = value[src_value_idx];
+  }
+}
 } // namespace vllm

 #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE)  \
@@ -275,6 +310,51 @@ void reshape_and_cache(
   }
 }

+void reshape_and_cache_flash(
+  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
+  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
+  torch::Tensor& k_cache,       // [num_blocks, block_size, num_heads, head_size]
+  torch::Tensor& v_cache,       // [num_blocks, block_size, num_heads, head_size]
+  torch::Tensor& slot_mapping,  // [num_tokens]
+  const std::string& kv_cache_dtype)
+{
+  // FIXME: only supports the "auto" datatype, does not support fp8
+  if (kv_cache_dtype != "auto") {
+    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+  }
+  int num_tokens = key.size(0);
+  int num_heads = key.size(1);
+  int head_size = key.size(2);
+  int block_size = k_cache.size(1);
+
+  int key_stride = key.stride(0);
+  int value_stride = value.stride(0);
+  int block_stride = k_cache.stride(0);
+  TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+    key.scalar_type(),
+    "reshape_and_cache_flash",
+    [&] {
+      vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        key.data_ptr<scalar_t>(),
+        value.data_ptr<scalar_t>(),
+        k_cache.data_ptr<scalar_t>(),
+        v_cache.data_ptr<scalar_t>(),
+        slot_mapping.data_ptr<int64_t>(),
+        block_stride,
+        key_stride,
+        value_stride,
+        num_heads,
+        head_size,
+        block_size);
+    });
+}
+
 namespace vllm {

 template
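Note: the new kernel scatters each token's key/value vectors into a cache laid out as `[num_blocks, block_size, num_heads, head_size]`, skipping tokens whose slot is -1. A minimal PyTorch sketch of the same indexing (illustrative names only; it mirrors the reference loop in the unit test further down):

```python
import torch

def reshape_and_cache_flash_ref(key, value, key_cache, value_cache, slot_mapping):
    """Pure-PyTorch sketch of the kernel's semantics (not the CUDA fast path).

    key / value:             [num_tokens, num_heads, head_size]
    key_cache / value_cache: [num_blocks, block_size, num_heads, head_size]
    slot_mapping:            [num_tokens]; -1 marks a padded token to skip
    """
    block_size = key_cache.shape[1]
    for token_idx, slot in enumerate(slot_mapping.tolist()):
        if slot < 0:
            continue  # padded token
        block_idx, block_offset = slot // block_size, slot % block_size
        key_cache[block_idx, block_offset] = key[token_idx]
        value_cache[block_idx, block_offset] = value[token_idx]
```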
diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp
index 9839bfc0331c..173e0b1732e1 100644
--- a/csrc/pybind.cpp
+++ b/csrc/pybind.cpp
@@ -96,6 +96,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "reshape_and_cache",
     &reshape_and_cache,
     "Reshape the key and value tensors and cache them");
+  cache_ops.def(
+    "reshape_and_cache_flash",
+    &reshape_and_cache_flash,
+    "Reshape the key and value tensors and cache them");
   cache_ops.def(
     "convert_fp8",
     &convert_fp8,
diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
index 97cff623c5e1..d75279dd9cfa 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -2,12 +2,15 @@
 Run `pytest tests/basic_correctness/test_basic_correctness.py`.
 """
+import os
+
 import pytest

 MODELS = [
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
 ]
+VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"


 @pytest.mark.parametrize("model", MODELS)
@@ -23,11 +26,18 @@ def test_models(
     max_tokens: int,
     enforce_eager: bool,
 ) -> None:
+    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
+    if backend_by_env_var == "FLASHINFER" and enforce_eager is False:
+        pytest.skip("Skipping non-eager test for FlashInferBackend.")
+
     hf_model = hf_runner(model, dtype=dtype)
     hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
     del hf_model

-    vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager)
+    vllm_model = vllm_runner(model,
+                             dtype=dtype,
+                             enforce_eager=enforce_eager,
+                             gpu_memory_utilization=0.7)
     vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
     del vllm_model
diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py
index 77aa90b12bf8..527452630c9f 100644
--- a/tests/distributed/test_basic_distributed_correctness.py
+++ b/tests/distributed/test_basic_distributed_correctness.py
@@ -18,6 +18,7 @@ import torch
 MODELS = [
     os.environ["TEST_DIST_MODEL"],
 ]
+VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"


 @pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -33,16 +34,19 @@ def test_models(
     dtype: str,
     max_tokens: int,
 ) -> None:
+    enforce_eager = False
+    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
+    if backend_by_env_var == "FLASHINFER":
+        enforce_eager = True
     hf_model = hf_runner(model, dtype=dtype)
     hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
     del hf_model

-    vllm_model = vllm_runner(
-        model,
-        dtype=dtype,
-        tensor_parallel_size=2,
-    )
+    vllm_model = vllm_runner(model,
+                             dtype=dtype,
+                             tensor_parallel_size=2,
+                             enforce_eager=enforce_eager)
     vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
     del vllm_model
diff --git a/tests/kernels/conftest.py b/tests/kernels/conftest.py
index d26da2c7fe4e..4f2f9cc3dac7 100644
--- a/tests/kernels/conftest.py
+++ b/tests/kernels/conftest.py
@@ -1,8 +1,14 @@
 import pytest

-from vllm.utils import create_kv_caches_with_random
+from vllm.utils import (create_kv_caches_with_random,
+                        create_kv_caches_with_random_flash)


 @pytest.fixture()
 def kv_cache_factory():
     return create_kv_caches_with_random
+
+
+@pytest.fixture()
+def kv_cache_factory_flashinfer():
+    return create_kv_caches_with_random_flash
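The two correctness tests above key off the `VLLM_ATTENTION_BACKEND` environment variable and force eager mode for FlashInfer. A hedged usage sketch of running with the new backend the same way (assuming the selector reads this variable; the model name is just an example):

```python
import os

# Must be set before vLLM picks an attention backend.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

# Eager mode is required for the FlashInfer backend (see the selector change below).
llm = LLM(model="facebook/opt-125m", enforce_eager=True)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```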
diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index d1051fd7e2f4..ca215bb75837 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -5,6 +5,7 @@ import pytest
 import torch

 from vllm import _custom_ops as ops
+from vllm._C import cache_ops
 from vllm.utils import is_hip

 COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
@@ -191,6 +192,82 @@ def test_reshape_and_cache(
     assert torch.allclose(value_cache, cloned_value_cache)


+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
+@torch.inference_mode()
+def test_reshape_and_cache_flash(
+    kv_cache_factory_flashinfer,
+    num_tokens: int,
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    kv_cache_dtype: str,
+) -> None:
+    if kv_cache_dtype == "fp8":
+        pytest.skip()
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    # Create a random slot mapping.
+    num_slots = block_size * num_blocks
+    slot_mapping = random.sample(range(num_slots), num_tokens)
+    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda')
+
+    qkv = torch.randn(num_tokens,
+                      3,
+                      num_heads,
+                      head_size,
+                      dtype=dtype,
+                      device=device)
+    _, key, value = qkv.unbind(dim=1)
+
+    # Create the KV caches.
+    key_caches, value_caches = kv_cache_factory_flashinfer(
+        num_blocks,
+        block_size,
+        1,
+        num_heads,
+        head_size,
+        kv_cache_dtype,
+        dtype,
+    )
+    key_cache, value_cache = key_caches[0], value_caches[0]
+
+    # Clone the KV caches.
+    cloned_key_cache = key_cache.clone()
+    cloned_value_cache = value_cache.clone()
+
+    # Call the reshape_and_cache_flash kernel.
+    cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
+                                      slot_mapping, kv_cache_dtype)
+
+    # Run the reference implementation.
+    block_indices = torch.div(slot_mapping, block_size, rounding_mode='floor')
+    block_indices = block_indices.cpu().tolist()
+    block_offsets = slot_mapping % block_size
+    block_offsets = block_offsets.cpu().tolist()
+    for i in range(num_tokens):
+        block_idx = block_indices[i]
+        block_offset = block_offsets[i]
+        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
+        cloned_value_cache[block_idx, block_offset, :, :] = value[i]
+
+    assert torch.allclose(key_cache, cloned_key_cache)
+    assert torch.allclose(value_cache, cloned_value_cache)
+
+
 @pytest.mark.parametrize("direction", COPYING_DIRECTION)
 @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index b43f646fec88..5b5643748747 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -222,6 +222,18 @@ def reshape_and_cache(
                                      slot_mapping, kv_cache_dtype, kv_scale)


+def reshape_and_cache_flash(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+) -> None:
+    vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
+                                           slot_mapping, kv_cache_dtype)
+
+
 def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
                 block_mapping: torch.Tensor) -> None:
     vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index be747c990036..61c9c81d8a7b 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, fields
-from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar
+from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
+                    TypeVar)

 import torch

@@ -15,7 +16,7 @@ class AttentionBackend(ABC):

     @staticmethod
     @abstractmethod
-    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
+    def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage":
         raise NotImplementedError

     @staticmethod
@@ -50,13 +51,17 @@ class AttentionBackend(ABC):
 class AttentionMetadataPerStage:
     """Attention metadata for a specific stage. I.e., prefill or decode."""

-    def asdict_zerocopy(self) -> Dict[str, Any]:
+    def asdict_zerocopy(self,
+                        skip_fields: Optional[Set[str]] = None
+                        ) -> Dict[str, Any]:
         """Similar to dataclasses.asdict, but avoids deepcopying."""
+        if skip_fields is None:
+            skip_fields = set()
         # Note that if we add dataclasses as fields, they will need
         # similar handling.
         return {
             field.name: getattr(self, field.name)
-            for field in fields(self)
+            for field in fields(self) if field.name not in skip_fields
         }
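The `skip_fields` hook added above lets a metadata subclass drop fields that cannot be broadcast to other workers. A small illustrative sketch (the `ExampleMetadata` dataclass is hypothetical, not part of this diff):

```python
from dataclasses import dataclass
from typing import Any, Optional

from vllm.attention.backends.abstract import AttentionMetadataPerStage


@dataclass
class ExampleMetadata(AttentionMetadataPerStage):
    is_prompt: bool
    decode_wrapper: Optional[Any] = None  # not picklable/broadcastable


meta = ExampleMetadata(is_prompt=False, decode_wrapper=object())
payload = meta.asdict_zerocopy(skip_fields={"decode_wrapper"})
assert "decode_wrapper" not in payload and payload["is_prompt"] is False
```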
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
new file mode 100644
index 000000000000..8ab4b1f12ee3
--- /dev/null
+++ b/vllm/attention/backends/flashinfer.py
@@ -0,0 +1,220 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Set, Tuple, Type
+
+try:
+    import flashinfer
+    from flash_attn import flash_attn_varlen_func
+    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+except ImportError:
+    flashinfer = None
+    flash_attn_varlen_func = None
+    BatchDecodeWithPagedKVCacheWrapper = None
+
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionMetadata,
+                                              AttentionMetadataPerStage)
+
+
+class FlashInferBackend(AttentionBackend):
+
+    @staticmethod
+    def get_impl_cls() -> Type["FlashInferImpl"]:
+        return FlashInferImpl
+
+    @staticmethod
+    def make_metadata(*args, **kwargs) -> "FlashInferMetadata":
+        return FlashInferMetadata(*args, **kwargs)
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        return (num_blocks, 2, block_size, num_kv_heads, head_size)
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: Dict[int, int],
+    ) -> None:
+        raise NotImplementedError
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: Dict[int, List[int]],
+    ) -> None:
+        raise NotImplementedError
+
+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [64, 128, 256]
+
+
+@dataclass
+class FlashInferMetadata(AttentionMetadataPerStage):
+
+    is_prompt: bool
+
+    use_cuda_graph: bool = False
+
+    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
+
+    # Metadata for the prefill stage since we still
+    # use flash attention for prefill.
+    seq_start_loc: Optional[torch.Tensor] = None
+    max_seq_len: Optional[int] = None
+    block_tables: Optional[torch.Tensor] = None
+
+    # Metadata for the decode stage.
+    # Workspace buffer required by the kernel; the buffer should not
+    # be allocated/deallocated by the FlashInferMetadata object.
+    workspace_buffer: Optional[torch.Tensor] = None
+    # An example for paged_kv_indices, paged_kv_indptr:
+    # request 1, page indices [0, 5, 8]
+    # request 2, page indices [1, 6, 7]
+    # request 3, page indices [3, 4]
+    # paged_kv_indices is a concatenation of page indices of all requests:
+    # [0, 5, 8, 1, 6, 7, 3, 4]
+    # paged_kv_indptr is used to index into paged_kv_indices:
+    # [0, 3, 6, 8]
+    # The indptr of the paged kv cache, shape: [batch_size + 1]
+    paged_kv_indptr: Optional[torch.Tensor] = None
+    # The page indices of the paged kv cache
+    paged_kv_indices: Optional[torch.Tensor] = None
+    # The number of entries in the last page of each request in
+    # the paged kv cache, shape: [batch_size]
+    paged_kv_last_page_len: Optional[torch.Tensor] = None
+    # The number of query/output heads
+    num_qo_heads: Optional[int] = None
+    # The number of key/value heads
+    num_kv_heads: Optional[int] = None
+    # The dimension of the attention heads
+    head_dim: Optional[int] = None
+    # Block size of vllm
+    page_size: Optional[int] = None
+    # The data type of the paged kv cache
+    data_type: torch.dtype = None
+
+    def __post_init__(self):
+        # Refer to
+        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
+        supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
+        if self.head_dim is not None and self.head_dim \
+                not in supported_head_sizes:
+            raise ValueError(
+                f"Only {supported_head_sizes} are supported for head_dim,",
+                f"received {self.head_dim}.")
+
+        # When using flashinfer, we are also creating the FlashInferMetadata,
+        # which will also call __post_init__ by default; here we want to skip
+        # the decode-wrapper setup if it's the prefill phase.
+        if not self.is_prompt:
+            self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
+                self.workspace_buffer, "NHD")
+            self.decode_wrapper.begin_forward(
+                self.paged_kv_indptr,
+                self.paged_kv_indices,
+                self.paged_kv_last_page_len,
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                self.page_size,
+                # Disable flashinfer's pos encoding and use vllm's rope.
+                pos_encoding_mode="NONE",
+                data_type=self.data_type)
+
+    def asdict_zerocopy(self,
+                        skip_fields: Optional[Set[str]] = None
+                        ) -> Dict[str, Any]:
+        if skip_fields is None:
+            skip_fields = set()
+        # We need to skip the decode_wrapper field since it cannot be
+        # broadcast with nccl when TP is enabled.
+        skip_fields.add('decode_wrapper')
+        return super().asdict_zerocopy(skip_fields)
+
+
+class FlashInferImpl(AttentionImpl):
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: Optional[int] = None,
+        alibi_slopes: Optional[List[float]] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
+        if sliding_window is not None:
+            raise ValueError("Sliding window is not supported in FlashInfer.")
+        self.sliding_window = (-1, -1)
+        self.alibi_slopes = alibi_slopes
+        self.scale = scale
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+
+    def forward(self, query: torch.Tensor, key: torch.Tensor,
+                value: torch.Tensor, kv_cache: Optional[torch.Tensor],
+                attn_metadata: AttentionMetadata[FlashInferMetadata],
+                kv_scale: float):
+        num_tokens, hidden_size = query.shape
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        if attn_metadata.num_prefill_tokens > 0:
+            assert attn_metadata.num_decode_tokens == 0, (
+                "Chunked prefill is not supported with flashinfer yet.")
+        if attn_metadata.num_decode_tokens > 0:
+            assert attn_metadata.num_prefill_tokens == 0, (
+                "Chunked prefill is not supported with flashinfer yet.")
+
+        if kv_cache is not None:
+            # Use the same reshape and cache kernel as flash attention.
+            ops.reshape_and_cache_flash(
+                key,
+                value,
+                kv_cache[:, 0],
+                kv_cache[:, 1],
+                attn_metadata.slot_mapping.flatten(),
+                attn_metadata.kv_cache_dtype,
+            )
+
+        if prefill_meta := attn_metadata.prefill_metadata:
+            assert prefill_meta.block_tables is not None
+            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
+                output = flash_attn_varlen_func(
+                    q=query,
+                    k=key,
+                    v=value,
+                    cu_seqlens_q=prefill_meta.seq_start_loc,
+                    cu_seqlens_k=prefill_meta.seq_start_loc,
+                    max_seqlen_q=prefill_meta.max_seq_len,
+                    max_seqlen_k=prefill_meta.max_seq_len,
+                    softmax_scale=self.scale,
+                    causal=True,
+                    window_size=self.sliding_window,
+                    alibi_slopes=self.alibi_slopes,
+                )
+            else:
+                raise NotImplementedError(
+                    "Prefix caching is not supported with flashinfer yet.")
+        else:
+            assert attn_metadata.decode_metadata is not None
+            assert attn_metadata.decode_metadata.decode_wrapper is not None
+            # Flashinfer requires query to be contiguous.
+            query = query.contiguous()
+            output = attn_metadata.decode_metadata.decode_wrapper.forward(
+                query,
+                kv_cache,
+                sm_scale=self.scale,
+            )
+        return output.view(num_tokens, hidden_size)
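For orientation, here is a standalone sketch of the decode path this backend drives, mirroring the calls made in `FlashInferMetadata.__post_init__` and `FlashInferImpl.forward` above (toy shapes and values; assumes the flashinfer version targeted by this PR):

```python
import torch
import flashinfer

# Toy setup: 2 decode requests, 16-token pages, 4 query/KV heads, head_dim 128.
num_qo_heads, num_kv_heads, head_dim, page_size = 4, 4, 128, 16
paged_kv_indices = torch.tensor([0, 5, 8, 1, 6], dtype=torch.int32, device="cuda")
paged_kv_indptr = torch.tensor([0, 3, 5], dtype=torch.int32, device="cuda")
paged_kv_last_page_len = torch.tensor([7, 16], dtype=torch.int32, device="cuda")

workspace = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")
wrapper.begin_forward(
    paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
    pos_encoding_mode="NONE",  # RoPE is applied by vLLM, not by FlashInfer.
    data_type=torch.float16)

# One query token per decoding request; kv_cache follows get_kv_cache_shape().
query = torch.randn(2, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
kv_cache = torch.randn(16, 2, page_size, num_kv_heads, head_dim,
                       dtype=torch.float16, device="cuda")
output = wrapper.forward(query.contiguous(), kv_cache, sm_scale=head_dim**-0.5)
```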
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 7ae8c31fae1a..34da0f6c6cdf 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -17,6 +17,7 @@ class _Backend(enum.Enum):
     XFORMERS = enum.auto()
     ROCM_FLASH = enum.auto()
     TORCH_SDPA = enum.auto()
+    FLASHINFER = enum.auto()


 @lru_cache(maxsize=None)
@@ -41,6 +42,11 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
         logger.info("Using Torch SDPA backend.")
         from vllm.attention.backends.torch_sdpa import TorchSDPABackend
         return TorchSDPABackend
+    elif backend == _Backend.FLASHINFER:
+        logger.info("Using Flashinfer backend.")
+        logger.warning("Eager mode is enforced for the Flashinfer backend.")
+        from vllm.attention.backends.flashinfer import FlashInferBackend
+        return FlashInferBackend
     else:
         raise ValueError("Invalid attention backend.")
") + from vllm.attention.backends.flashinfer import FlashInferBackend + return FlashInferBackend else: raise ValueError("Invalid attention backend.") diff --git a/vllm/config.py b/vllm/config.py index 3bdd3f774bc2..fe54c54bed48 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -298,6 +298,11 @@ class ModelConfig: return max(1, total_num_kv_heads // parallel_config.tensor_parallel_size) + def get_num_attention_heads(self, + parallel_config: "ParallelConfig") -> int: + return self.hf_text_config.num_attention_heads // \ + parallel_config.tensor_parallel_size + def get_num_layers(self, parallel_config: "ParallelConfig") -> int: total_num_hidden_layers = self.hf_text_config.num_hidden_layers return total_num_hidden_layers // parallel_config.pipeline_parallel_size diff --git a/vllm/sequence.py b/vllm/sequence.py index 0e931ebbb657..8caf97d30d53 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -579,8 +579,10 @@ class SequenceGroupMetadata: query tokens for prefill, we don't need sampling. token_chunk_size: The number of tokens to be processed (per sequence). None if chunking is not required. - state: Internal state tied to this sequence group. lora_request: LoRA request. + computed_block_nums: The block numbers that are already computed, + used in prefix caching. + state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. """ diff --git a/vllm/utils.py b/vllm/utils.py index ce55253ce219..b06c8508757c 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -355,21 +355,9 @@ def _generate_random_fp8( del tensor_tmp -def create_kv_caches_with_random( - num_blocks: int, - block_size: int, - num_layers: int, - num_heads: int, - head_size: int, - cache_dtype: Optional[Union[str, torch.dtype]], - model_dtype: Optional[Union[str, torch.dtype]] = None, - seed: int = 0, - device: Optional[str] = "cuda", -) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - +def get_kv_cache_torch_dtype( + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype: if isinstance(cache_dtype, str): if cache_dtype == "auto": if isinstance(model_dtype, str): @@ -388,6 +376,55 @@ def create_kv_caches_with_random( torch_dtype = cache_dtype else: raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + return torch_dtype + + +def create_kv_caches_with_random_flash( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: int = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + assert cache_dtype != "fp8" + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) + scale = head_size**-0.5 + key_caches, value_caches = [], [] + for _ in range(num_layers): + key_value_cache = torch.empty(size=key_value_cache_shape, + dtype=torch_dtype, + device=device) + key_value_cache.uniform_(-scale, scale) + key_caches.append(key_value_cache[:, 0]) + value_caches.append(key_value_cache[:, 1]) + return key_caches, value_caches + + +def create_kv_caches_with_random( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: 
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index bbb1f5205af5..ab248596490f 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -9,6 +9,7 @@ import torch.nn as nn

 from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
                             get_attn_backend)
+from vllm.attention.backends.flashinfer import FlashInferBackend
 from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
@@ -23,8 +24,8 @@ from vllm.model_executor.model_loader import get_model
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
                            SequenceGroupMetadata)
-from vllm.utils import (CudaMemoryProfiler, is_hip, is_pin_memory_available,
-                        make_tensor_with_pad)
+from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
+                        is_pin_memory_available, make_tensor_with_pad)

 logger = init_logger(__name__)

@@ -155,6 +156,9 @@ class ModelRunner:
         # (max batch size to capture, max context len to capture / block size).
         self.graph_block_tables: torch.Tensor  # Set after initial profiling.

+        # Set if the backend is flashinfer.
+        self.flashinfer_workspace_buffer: torch.Tensor
+
     def load_model(self) -> None:
         with CudaMemoryProfiler() as m:
             self.model = get_model(
@@ -315,6 +319,7 @@ class ModelRunner:

                 # Compute the slot mapping.
                 block_table = seq_group_metadata.block_tables[seq_id]
+
                 # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
                 # where start_idx is max(0, seq_len - sliding_window).
                 # For example, if the prompt len is 10, sliding window is 8, and
@@ -390,18 +395,26 @@ class ModelRunner:
                      dtype=seq_start_loc.dtype,
                      out=seq_start_loc[1:])

-        attn_metadata = self.attn_backend.make_metadata(
-            is_prompt=True,
-            seq_lens=seq_lens,
-            seq_lens_tensor=seq_lens_tensor,
-            max_query_len=max_query_len,
-            max_seq_len=max_seq_len,
-            subquery_start_loc=subquery_start_loc,
-            seq_start_loc=seq_start_loc,
-            context_lens_tensor=context_lens_tensor,
-            block_tables=block_tables,
-            use_cuda_graph=False,
-        )
+        if self.attn_backend is FlashInferBackend:
+            attn_metadata = self.attn_backend.make_metadata(
+                is_prompt=True,
+                use_cuda_graph=False,
+                seq_start_loc=seq_start_loc,
+                max_seq_len=max_seq_len,
+                block_tables=block_tables)
+        else:
+            attn_metadata = self.attn_backend.make_metadata(
+                is_prompt=True,
+                seq_lens=seq_lens,
+                seq_lens_tensor=seq_lens_tensor,
+                max_query_len=max_query_len,
+                max_seq_len=max_seq_len,
+                subquery_start_loc=subquery_start_loc,
+                seq_start_loc=seq_start_loc,
+                context_lens_tensor=context_lens_tensor,
+                block_tables=block_tables,
+                use_cuda_graph=False,
+            )

         return PreparePromptMetadata(
             input_tokens=input_tokens,
@@ -429,6 +442,24 @@ class ModelRunner:
         lora_prompt_mapping: List[int] = []
         lora_requests: Set[LoRARequest] = set()

+        # The following fields are only for flashinfer.
+        # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
+        # for the precise definition of the following fields.
+        # An example:
+        # request 1, page indices [0, 5, 8]
+        # request 2, page indices [1, 6, 7]
+        # request 3, page indices [3, 4]
+        # paged_kv_indices is a concatenation of page indices of all requests:
+        # [0, 5, 8, 1, 6, 7, 3, 4]
+        # paged_kv_indptr is used to index into paged_kv_indices:
+        # [0, 3, 6, 8]
+        paged_kv_indices: List[int] = []
+        # 0 at the beginning of paged_kv_indptr indicates the start of the
+        # first request's page indices in the paged_kv_indices list.
+        paged_kv_indptr: List[int] = [0]
+        # paged_kv_last_page_len is the length of the last page of each request.
+        paged_kv_last_page_len: List[int] = []
+
         if len(seq_group_metadata_list) == 0:
             return PrepareDecodeMetadata.empty()
@@ -469,6 +500,13 @@ class ModelRunner:
                     block_table = block_table[-sliding_window_blocks:]
                 block_tables.append(block_table)

+                paged_kv_indices.extend(block_table)
+                paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table))
+                last_page_len = seq_data.get_len() % self.block_size
+                if last_page_len == 0:
+                    last_page_len = self.block_size
+                paged_kv_last_page_len.append(last_page_len)
+
         # vLLM uses cuda graph only for decoding requests.
         # See `capture_model` API for more details.
         # For decoding requests, batch_size == input_tokens.
@@ -518,18 +556,51 @@ class ModelRunner:
             device=self.device,
         )

-        attn_metadata = self.attn_backend.make_metadata(
-            is_prompt=False,
-            seq_lens=None,
-            seq_lens_tensor=seq_lens_tensor,
-            max_query_len=None,
-            max_seq_len=max_seq_len,
-            subquery_start_loc=None,
-            seq_start_loc=None,
-            context_lens_tensor=None,
-            block_tables=block_tables,
-            use_cuda_graph=use_captured_graph,
-        )
+        if self.attn_backend is FlashInferBackend:
+            if not hasattr(self, "flashinfer_workspace_buffer"):
+                # Allocate 16MB workspace buffer
+                # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
+                self.flashinfer_workspace_buffer = torch.empty(
+                    16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
+            paged_kv_indptr = torch.tensor(paged_kv_indptr,
+                                           dtype=torch.int,
+                                           device=self.device)
+            paged_kv_indices = torch.tensor(paged_kv_indices,
+                                            dtype=torch.int,
+                                            device=self.device)
+            paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len,
+                                                  dtype=torch.int,
+                                                  device=self.device)
+            kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
+                                                      self.model_config.dtype)
+
+            attn_metadata = self.attn_backend.make_metadata(
+                is_prompt=False,
+                use_cuda_graph=False,
+                workspace_buffer=self.flashinfer_workspace_buffer,
+                paged_kv_indptr=paged_kv_indptr,
+                paged_kv_indices=paged_kv_indices,
+                paged_kv_last_page_len=paged_kv_last_page_len,
+                num_qo_heads=self.model_config.get_num_attention_heads(
+                    self.parallel_config),
+                num_kv_heads=self.model_config.get_num_kv_heads(
+                    self.parallel_config),
+                head_dim=self.model_config.get_head_size(),
+                page_size=self.block_size,
+                data_type=kv_cache_dtype)
+        else:
+            attn_metadata = self.attn_backend.make_metadata(
+                is_prompt=False,
+                seq_lens=None,
+                seq_lens_tensor=seq_lens_tensor,
+                max_query_len=None,
+                max_seq_len=max_seq_len,
+                subquery_start_loc=None,
+                seq_start_loc=None,
+                context_lens_tensor=None,
+                block_tables=block_tables,
+                use_cuda_graph=use_captured_graph,
+            )

         return PrepareDecodeMetadata(
             input_tokens=input_tokens,
             input_positions=input_positions,
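Note on the decode metadata built above: the indptr/indices/last-page-length triple follows FlashInfer's page layout. A small self-contained sketch of the same bookkeeping, reusing the [0, 5, 8] / [1, 6, 7] / [3, 4] example from the comments (block_size of 4 chosen for illustration):

```python
block_tables = [[0, 5, 8], [1, 6, 7], [3, 4]]  # pages per decoding request
seq_lens = [11, 12, 8]                         # tokens per request
block_size = 4

paged_kv_indices = []
paged_kv_indptr = [0]
paged_kv_last_page_len = []
for block_table, seq_len in zip(block_tables, seq_lens):
    paged_kv_indices.extend(block_table)
    paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table))
    last_page_len = seq_len % block_size or block_size
    paged_kv_last_page_len.append(last_page_len)

assert paged_kv_indices == [0, 5, 8, 1, 6, 7, 3, 4]
assert paged_kv_indptr == [0, 3, 6, 8]
assert paged_kv_last_page_len == [3, 4, 4]
```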