mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-21 01:46:59 +08:00
[ROCm] Faster Custom Paged Attention kernels (#12348)
This commit is contained in:
parent
98175b2816
commit
848a6438ae
@ -77,7 +77,6 @@ echo "Commands:$commands"
|
|||||||
#ignore certain kernels tests
|
#ignore certain kernels tests
|
||||||
if [[ $commands == *" kernels "* ]]; then
|
if [[ $commands == *" kernels "* ]]; then
|
||||||
commands="${commands} \
|
commands="${commands} \
|
||||||
--ignore=kernels/test_attention.py \
|
|
||||||
--ignore=kernels/test_attention_selector.py \
|
--ignore=kernels/test_attention_selector.py \
|
||||||
--ignore=kernels/test_blocksparse_attention.py \
|
--ignore=kernels/test_blocksparse_attention.py \
|
||||||
--ignore=kernels/test_causal_conv1d.py \
|
--ignore=kernels/test_causal_conv1d.py \
|
||||||
|
|||||||
@ -11,8 +11,9 @@ from vllm.platforms import current_platform
|
|||||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
|
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
|
||||||
create_kv_caches_with_random)
|
create_kv_caches_with_random)
|
||||||
|
|
||||||
NUM_BLOCKS = 1024
|
NUM_BLOCKS = 128 * 1024
|
||||||
PARTITION_SIZE = 512
|
PARTITION_SIZE = 512
|
||||||
|
PARTITION_SIZE_ROCM = 256
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
@ -80,6 +81,12 @@ def main(
|
|||||||
# Prepare for the paged attention kernel.
|
# Prepare for the paged attention kernel.
|
||||||
output = torch.empty_like(query)
|
output = torch.empty_like(query)
|
||||||
if version == "v2":
|
if version == "v2":
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
global PARTITION_SIZE
|
||||||
|
if not args.custom_paged_attn:
|
||||||
|
PARTITION_SIZE = 1024
|
||||||
|
else:
|
||||||
|
PARTITION_SIZE = PARTITION_SIZE_ROCM
|
||||||
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
|
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
|
||||||
tmp_output = torch.empty(
|
tmp_output = torch.empty(
|
||||||
size=(num_seqs, num_query_heads, num_partitions, head_size),
|
size=(num_seqs, num_query_heads, num_partitions, head_size),
|
||||||
@ -123,25 +130,46 @@ def main(
|
|||||||
v_scale,
|
v_scale,
|
||||||
)
|
)
|
||||||
elif version == "v2":
|
elif version == "v2":
|
||||||
ops.paged_attention_v2(
|
if not args.custom_paged_attn:
|
||||||
output,
|
ops.paged_attention_v2(
|
||||||
exp_sums,
|
output,
|
||||||
max_logits,
|
exp_sums,
|
||||||
tmp_output,
|
max_logits,
|
||||||
query,
|
tmp_output,
|
||||||
key_cache,
|
query,
|
||||||
value_cache,
|
key_cache,
|
||||||
num_kv_heads,
|
value_cache,
|
||||||
scale,
|
num_kv_heads,
|
||||||
block_tables,
|
scale,
|
||||||
seq_lens,
|
block_tables,
|
||||||
block_size,
|
seq_lens,
|
||||||
max_seq_len,
|
block_size,
|
||||||
alibi_slopes,
|
max_seq_len,
|
||||||
kv_cache_dtype,
|
alibi_slopes,
|
||||||
k_scale,
|
kv_cache_dtype,
|
||||||
v_scale,
|
k_scale,
|
||||||
)
|
v_scale,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ops.paged_attention_rocm(
|
||||||
|
output,
|
||||||
|
exp_sums,
|
||||||
|
max_logits,
|
||||||
|
tmp_output,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
num_kv_heads,
|
||||||
|
scale,
|
||||||
|
block_tables,
|
||||||
|
seq_lens,
|
||||||
|
block_size,
|
||||||
|
max_seq_len,
|
||||||
|
alibi_slopes,
|
||||||
|
kv_cache_dtype,
|
||||||
|
k_scale,
|
||||||
|
v_scale,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid version: {version}")
|
raise ValueError(f"Invalid version: {version}")
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
@ -195,6 +223,9 @@ if __name__ == '__main__':
|
|||||||
help="Data type for kv cache storage. If 'auto', will use model "
|
help="Data type for kv cache storage. If 'auto', will use model "
|
||||||
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
|
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
|
||||||
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
|
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
|
||||||
|
parser.add_argument("--custom-paged-attn",
|
||||||
|
action="store_true",
|
||||||
|
help="Use custom paged attention")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
print(args)
|
print(args)
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -25,6 +25,7 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
|
|||||||
# Reduce NUM_BLOCKS when it happens.
|
# Reduce NUM_BLOCKS when it happens.
|
||||||
NUM_BLOCKS = 4321 # Arbitrary values for testing
|
NUM_BLOCKS = 4321 # Arbitrary values for testing
|
||||||
PARTITION_SIZE = 512
|
PARTITION_SIZE = 512
|
||||||
|
PARTITION_SIZE_ROCM = 256
|
||||||
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
|
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
|
||||||
DTYPES = [
|
DTYPES = [
|
||||||
torch.half, torch.bfloat16, torch.float
|
torch.half, torch.bfloat16, torch.float
|
||||||
@ -146,6 +147,8 @@ def test_paged_attention(
|
|||||||
or (version == "rocm" and head_size not in (64, 128))):
|
or (version == "rocm" and head_size not in (64, 128))):
|
||||||
pytest.skip()
|
pytest.skip()
|
||||||
|
|
||||||
|
global PARTITION_SIZE
|
||||||
|
|
||||||
current_platform.seed_everything(seed)
|
current_platform.seed_everything(seed)
|
||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
scale = float(1.0 / (head_size**0.5))
|
scale = float(1.0 / (head_size**0.5))
|
||||||
@ -214,6 +217,9 @@ def test_paged_attention(
|
|||||||
and block_size == BLOCK_SIZES[0]))
|
and block_size == BLOCK_SIZES[0]))
|
||||||
|
|
||||||
elif version in ("v2", "rocm"):
|
elif version in ("v2", "rocm"):
|
||||||
|
if current_platform.is_rocm() and version == "rocm":
|
||||||
|
PARTITION_SIZE = PARTITION_SIZE_ROCM
|
||||||
|
|
||||||
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
|
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
|
||||||
assert PARTITION_SIZE % block_size == 0
|
assert PARTITION_SIZE % block_size == 0
|
||||||
num_seqs, num_heads, head_size = output.shape
|
num_seqs, num_heads, head_size = output.shape
|
||||||
|
|||||||
@ -22,7 +22,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
_PARTITION_SIZE_ROCM = 512
|
_PARTITION_SIZE_ROCM = 256
|
||||||
_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
|
_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
|
||||||
_ON_NAVI = "gfx1" in _GPU_ARCH
|
_ON_NAVI = "gfx1" in _GPU_ARCH
|
||||||
_ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx942"])
|
_ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx942"])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user