From cf5cb1e33eed16b2f0d5fe6268bf5705a4d0ea5a Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Tue, 26 Sep 2023 22:27:13 -0700
Subject: [PATCH] Allocate more shared memory to attention kernel (#1154)

---
 csrc/attention/attention_kernels.cu |  5 +++++
 csrc/cuda_utils.cpp                 | 13 +++++++++++++
 csrc/cuda_utils_kernels.cu          | 14 ++++++++++++++
 setup.py                            | 11 +++++++++++
 tests/kernels/test_attention.py     |  8 +++++++-
 vllm/utils.py                       | 13 ++++++++++++-
 vllm/worker/worker.py               | 26 +++++++++++++++++++++++++-
 7 files changed, 87 insertions(+), 3 deletions(-)
 create mode 100644 csrc/cuda_utils.cpp
 create mode 100644 csrc/cuda_utils_kernels.cu

diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index 3fc5860bf147..8955b503bdd1 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -341,6 +341,9 @@ __global__ void single_query_cached_kv_attention_kernel(
 } // namespace vllm
 
 #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
+  cudaFuncSetAttribute( \
+    vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
+    cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
   vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
   <<<grid, block, shared_mem_size, stream>>>( \
     out_ptr, \
@@ -401,6 +404,8 @@ void single_query_cached_kv_attention_launcher(
   int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
   int logits_size = padded_max_context_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+  // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
+  // Keep that in sync with the logic here!
   int shared_mem_size = std::max(logits_size, outputs_size);
 
   dim3 grid(num_heads, num_seqs);
diff --git a/csrc/cuda_utils.cpp b/csrc/cuda_utils.cpp
new file mode 100644
index 000000000000..e7f22ec89d7b
--- /dev/null
+++ b/csrc/cuda_utils.cpp
@@ -0,0 +1,13 @@
+#include <torch/extension.h>
+
+int get_device_attribute(
+    int attribute,
+    int device_id);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def(
+    "get_device_attribute",
+    &get_device_attribute,
+    "Gets the specified device attribute.");
+}
+
diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu
new file mode 100644
index 000000000000..f1c30fe7ea99
--- /dev/null
+++ b/csrc/cuda_utils_kernels.cu
@@ -0,0 +1,14 @@
+int get_device_attribute(
+    int attribute,
+    int device_id)
+{
+    int device, value;
+    if (device_id < 0) {
+        cudaGetDevice(&device);
+    }
+    else {
+        device = device_id;
+    }
+    cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
+    return value;
+}
diff --git a/setup.py b/setup.py
index b7c0d9071fec..8b2ad97dd540 100644
--- a/setup.py
+++ b/setup.py
@@ -195,6 +195,17 @@ quantization_extension = CUDAExtension(
 )
 ext_modules.append(quantization_extension)
 
+# Misc. CUDA utils.
+cuda_utils_extension = CUDAExtension(
+    name="vllm.cuda_utils",
+    sources=["csrc/cuda_utils.cpp", "csrc/cuda_utils_kernels.cu"],
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
+)
+ext_modules.append(cuda_utils_extension)
+
 
 def get_path(*filepath) -> str:
     return os.path.join(ROOT_DIR, *filepath)
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index 18985669d159..813f6fdb59b2 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -7,8 +7,12 @@ from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
 from vllm import attention_ops
+from vllm.utils import get_max_shared_memory_bytes
 
-MAX_SEQ_LEN = 8192
+FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
+# This will change depending on the compute capability.
+# Subtract 512 as a buffer.
+MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
 NUM_BLOCKS = 128  # Arbitrary values for testing
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -135,6 +139,7 @@ def test_single_query_cached_kv_attention(
                               device="cuda")
 
     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
+    context_lens[-1] = MAX_SEQ_LEN
     max_context_len = max(context_lens)
     context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
 
@@ -243,6 +248,7 @@ def test_multi_query_kv_attention(
     torch.cuda.manual_seed(seed)
 
     seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
+    seq_lens[-1] = MAX_SEQ_LEN
     num_tokens = sum(seq_lens)
     scale = float(1.0 / (head_size**0.5))
 
diff --git a/vllm/utils.py b/vllm/utils.py
index eb10b3f50576..0e17e9070489 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1,10 +1,12 @@
 import enum
-from platform import uname
 import uuid
+from platform import uname
 
 import psutil
 import torch
 
+from vllm import cuda_utils
+
 
 class Device(enum.Enum):
     GPU = enum.auto()
@@ -25,6 +27,15 @@ class Counter:
         self.counter = 0
 
 
+def get_max_shared_memory_bytes(gpu: int = 0) -> int:
+    """Returns the maximum shared memory per thread block in bytes."""
+    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
+    cudaDevAttrMaxSharedMemoryPerBlockOptin = 97  # pylint: disable=invalid-name
+    max_shared_mem = cuda_utils.get_device_attribute(
+        cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
+    return int(max_shared_mem)
+
+
 def get_gpu_memory(gpu: int = 0) -> int:
     """Returns the total memory of the GPU in bytes."""
     return torch.cuda.get_device_properties(gpu).total_memory
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 586c90e0049c..3239a819794e 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -13,7 +13,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
-from vllm.utils import get_gpu_memory
+from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes
 
 
 class Worker:
@@ -136,6 +136,10 @@ class Worker:
     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
         self.block_size = cache_config.block_size
+
+        _check_if_can_support_max_seq_len(self.scheduler_config.max_model_len,
+                                          self.block_size)
+
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
         self.cache_events = self.cache_engine.events
@@ -347,3 +351,23 @@ def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
 
 def _pad_to_max(x: List[int], max_len: int) -> List[int]:
     return x + [0] * (max_len - len(x))
+
+
+def _check_if_can_support_max_seq_len(max_seq_len: int,
+                                      block_size: int) -> None:
+    # Follows the logic in
+    # attention_kernels.cu::single_query_cached_kv_attention_launcher
+    max_shared_mem = get_max_shared_memory_bytes()
+    float32_bytes = torch.finfo(torch.float).bits // 8
+    padded_max_seq_len = (
+        (max_seq_len + block_size - 1) // block_size) * block_size
+    # padded_max_seq_len + extra buffer
+    required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
+    if padded_max_seq_len * float32_bytes > max_shared_mem:
+        raise RuntimeError(
+            f"vLLM cannot currently support max_model_len={max_seq_len} "
+            f"with block_size={block_size} on GPU with compute "
+            f"capability {torch.cuda.get_device_capability()} "
+            f"(required shared memory {required_shared_mem} > "
+            f"available shared memory {max_shared_mem}). "
+            "This will be fixed in a future release.")
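
Usage sketch (illustrative, not part of the diff above): assuming the vllm.cuda_utils extension added in setup.py has been built and the patched vllm.utils is importable, the snippet below estimates the longest context length the single-query attention kernel can serve on the current GPU, mirroring the formula in tests/kernels/test_attention.py and the block padding in vllm/worker/worker.py. The helper name max_supported_seq_len is hypothetical and introduced here only for illustration.

# Standalone sketch: rough upper bound on the supportable context length,
# derived from the per-block shared memory reported by the new helper.
import torch

from vllm.utils import get_max_shared_memory_bytes

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8  # attention logits are accumulated in fp32


def max_supported_seq_len(block_size: int = 16) -> int:
    # One fp32 logit per (block-padded) context token must fit in shared memory.
    max_shared_mem = get_max_shared_memory_bytes()  # varies with compute capability
    max_len = max_shared_mem // FLOAT32_BYTES - 512  # keep a 512-float buffer, as in the tests
    # Round down to a whole number of KV-cache blocks.
    return (max_len // block_size) * block_size


if __name__ == "__main__":
    print(f"Max supported context length: {max_supported_seq_len()} tokens")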