Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-10 01:05:01 +08:00
Allocate more shared memory to attention kernel (#1154)
commit cf5cb1e33e
parent 03ffd0a022
attention_kernels.cu

@@ -341,6 +341,9 @@ __global__ void single_query_cached_kv_attention_kernel(
 } // namespace vllm

 #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS)                     \
+  cudaFuncSetAttribute(                                                                    \
+    vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>,  \
+    cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size);                         \
   vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>     \
   <<<grid, block, shared_mem_size, stream>>>(                                              \
     out_ptr,                                                                               \
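Background for this change (not part of the diff): CUDA caps dynamic shared memory at 48 KB per block by default, and a kernel that requests more must opt in via cudaFuncSetAttribute with cudaFuncAttributeMaxDynamicSharedMemorySize, which is what the macro now does before the launch. A minimal sketch of the sizing arithmetic; the context lengths and sizeof(float) == 4 are example assumptions:

    # Sketch only (not vLLM code): size of the float logits buffer at two context lengths
    # versus the default per-block dynamic shared memory cap.
    default_cap = 48 * 1024   # bytes per block without the opt-in
    print(8192 * 4)           # 32768 bytes -> an 8192-token context fits under 48 KB
    print(16384 * 4)          # 65536 bytes -> a 16384-token context needs the opt-in above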
@@ -401,6 +404,8 @@ void single_query_cached_kv_attention_launcher(
   int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
   int logits_size = padded_max_context_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+  // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
+  // Keep that in sync with the logic here!
   int shared_mem_size = std::max(logits_size, outputs_size);

   dim3 grid(num_heads, num_seqs);
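The comment added here points at a Python-side mirror of this sizing. A hedged sketch of that arithmetic follows (illustrative only; the C++ above is authoritative, and num_warps = 4 plus the example arguments are assumptions, not values read from the kernel):

    # Illustrative mirror of shared_mem_size, not vLLM API.
    def shared_mem_size_bytes(max_context_len: int, block_size: int,
                              head_size: int, num_warps: int = 4) -> int:
        float_bytes = 4
        padded_max_context_len = ((max_context_len + block_size - 1)
                                  // block_size) * block_size
        logits_size = padded_max_context_len * float_bytes          # softmax logits
        outputs_size = (num_warps // 2) * head_size * float_bytes   # per-warp partial outputs
        return max(logits_size, outputs_size)

    # e.g. max_context_len=4096, block_size=16, head_size=128:
    # logits 16384 B vs outputs 1024 B -> 16384 bytes of dynamic shared memory
    print(shared_mem_size_bytes(4096, 16, 128))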
csrc/cuda_utils.cpp (new file, +13)

@@ -0,0 +1,13 @@
+#include <torch/extension.h>
+
+int get_device_attribute(
+    int attribute,
+    int device_id);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def(
+    "get_device_attribute",
+    &get_device_attribute,
+    "Gets the specified device attribute.");
+}
+
csrc/cuda_utils_kernels.cu (new file, +14)

@@ -0,0 +1,14 @@
+int get_device_attribute(
+    int attribute,
+    int device_id)
+{
+    int device, value;
+    if (device_id < 0) {
+        cudaGetDevice(&device);
+    }
+    else {
+        device = device_id;
+    }
+    cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
+    return value;
+}
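A small usage sketch for the new binding, assuming the extension has been built (see the setup.py change below) and a CUDA device is present; attribute 97 is cudaDevAttrMaxSharedMemoryPerBlockOptin, the same value used in vllm/utils.py later in this diff:

    # Usage sketch: query the opt-in shared memory limit through the new module.
    from vllm import cuda_utils

    print(cuda_utils.get_device_attribute(97, 0))    # explicit device 0
    print(cuda_utils.get_device_attribute(97, -1))   # device_id < 0 -> current device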
setup.py (+11)

@@ -195,6 +195,17 @@ quantization_extension = CUDAExtension(
 )
 ext_modules.append(quantization_extension)

+# Misc. CUDA utils.
+cuda_utils_extension = CUDAExtension(
+    name="vllm.cuda_utils",
+    sources=["csrc/cuda_utils.cpp", "csrc/cuda_utils_kernels.cu"],
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
+)
+ext_modules.append(cuda_utils_extension)
+

 def get_path(*filepath) -> str:
     return os.path.join(ROOT_DIR, *filepath)
tests/kernels/test_attention.py

@@ -7,8 +7,12 @@ from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

 from vllm import attention_ops
+from vllm.utils import get_max_shared_memory_bytes

-MAX_SEQ_LEN = 8192
+FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
+# This will change depending on the compute capability.
+# - 512 as a buffer
+MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
 NUM_BLOCKS = 128  # Arbitrary values for testing

 DTYPES = [torch.half, torch.bfloat16, torch.float]
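Worked numbers for the new MAX_SEQ_LEN formula; the 96 KB opt-in limit is an assumed example value, since (as the comment notes) it varies with compute capability:

    # Assumed device: 96 KB of opt-in shared memory per thread block.
    FLOAT32_BYTES = 4
    max_shared_memory_bytes = 96 * 1024                     # 98304
    print(max_shared_memory_bytes // FLOAT32_BYTES - 512)   # 24064 tokens: one float logit each, minus a 512-token buffer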
@@ -135,6 +139,7 @@ def test_single_query_cached_kv_attention(
         device="cuda")

     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
+    context_lens[-1] = MAX_SEQ_LEN
     max_context_len = max(context_lens)
     context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

@@ -243,6 +248,7 @@ def test_multi_query_kv_attention(
     torch.cuda.manual_seed(seed)

     seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
+    seq_lens[-1] = MAX_SEQ_LEN
     num_tokens = sum(seq_lens)

     scale = float(1.0 / (head_size**0.5))
vllm/utils.py

@@ -1,10 +1,12 @@
 import enum
-from platform import uname
 import uuid
+from platform import uname

 import psutil
 import torch

+from vllm import cuda_utils
+

 class Device(enum.Enum):
     GPU = enum.auto()
@@ -25,6 +27,15 @@ class Counter:
         self.counter = 0


+def get_max_shared_memory_bytes(gpu: int = 0) -> int:
+    """Returns the maximum shared memory per thread block in bytes."""
+    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
+    cudaDevAttrMaxSharedMemoryPerBlockOptin = 97  # pylint: disable=invalid-name
+    max_shared_mem = cuda_utils.get_device_attribute(
+        cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
+    return int(max_shared_mem)
+
+
 def get_gpu_memory(gpu: int = 0) -> int:
     """Returns the total memory of the GPU in bytes."""
     return torch.cuda.get_device_properties(gpu).total_memory
vllm/worker/worker.py

@@ -13,7 +13,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
-from vllm.utils import get_gpu_memory
+from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes


 class Worker:
@@ -136,6 +136,10 @@ class Worker:
     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
         self.block_size = cache_config.block_size
+
+        _check_if_can_support_max_seq_len(self.scheduler_config.max_model_len,
+                                          self.block_size)
+
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
         self.cache_events = self.cache_engine.events
@@ -347,3 +351,23 @@ def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:

 def _pad_to_max(x: List[int], max_len: int) -> List[int]:
     return x + [0] * (max_len - len(x))
+
+
+def _check_if_can_support_max_seq_len(max_seq_len: int,
+                                      block_size: int) -> None:
+    # Follows the logic in
+    # attention_kernels.cu::single_query_cached_kv_attention_launcher
+    max_shared_mem = get_max_shared_memory_bytes()
+    float32_bytes = torch.finfo(torch.float).bits // 8
+    padded_max_seq_len = (
+        (max_seq_len + block_size - 1) / block_size) * block_size
+    # padded_max_seq_len + extra buffer
+    required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
+    if padded_max_seq_len * float32_bytes > max_shared_mem:
+        raise RuntimeError(
+            f"vLLM cannot currently support max_model_len={max_seq_len} "
+            f"with block_size={block_size} on GPU with compute "
+            f"capability {torch.cuda.get_device_capability()} "
+            f"(required shared memory {required_shared_mem} > "
+            f"available shared memory {max_shared_mem}). "
+            "This will be fixed in a future release.")
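Worked numbers for the check above, again assuming a 96 KB opt-in limit and example max_model_len and block_size values; as a side note, the `/ block_size` in the diff produces a float intermediate that happens to round-trip exactly here, though `//` would keep it an int:

    # Assumed device: 96 KB opt-in shared memory per block; example model settings.
    max_shared_mem = 96 * 1024                 # 98304 bytes available
    float32_bytes = 4
    max_seq_len, block_size = 32768, 16
    padded = ((max_seq_len + block_size - 1) // block_size) * block_size   # 32768
    print(padded * float32_bytes)              # 131072 bytes needed -> exceeds 98304, RuntimeError
    print(24064 * float32_bytes)               # 96256 -> the test MAX_SEQ_LEN above fits under 98304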