mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-14 08:57:13 +08:00
[BUG] fix crash on flashinfer backend with cudagraph disabled, when attention group_size not in [1,2,4,8] (#7509)
This commit is contained in:
parent
c75363fbc0
commit
53328d7536
@ -4,7 +4,7 @@ import flashinfer
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
|
NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)]
|
||||||
HEAD_SIZES = [128, 256]
|
HEAD_SIZES = [128, 256]
|
||||||
BLOCK_SIZES = [16, 32]
|
BLOCK_SIZES = [16, 32]
|
||||||
DTYPES = [torch.float16, torch.bfloat16]
|
DTYPES = [torch.float16, torch.bfloat16]
|
||||||
@ -123,7 +123,10 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
|
|||||||
|
|
||||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
|
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
|
||||||
wrapper = flashinfer.\
|
wrapper = flashinfer.\
|
||||||
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
|
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
|
||||||
|
use_tensor_cores=(
|
||||||
|
(num_query_heads//num_kv_heads) not in (1, 2, 4, 8))
|
||||||
|
)
|
||||||
wrapper.begin_forward(kv_indptr,
|
wrapper.begin_forward(kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
kv_last_page_lens,
|
kv_last_page_lens,
|
||||||
|
|||||||
@ -113,7 +113,8 @@ class FlashInferState(AttentionState):
|
|||||||
self.runner.parallel_config))
|
self.runner.parallel_config))
|
||||||
num_kv_heads = self.runner.model_config.get_num_kv_heads(
|
num_kv_heads = self.runner.model_config.get_num_kv_heads(
|
||||||
self.runner.parallel_config)
|
self.runner.parallel_config)
|
||||||
use_tensor_cores = num_qo_heads // num_kv_heads >= 4
|
use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
|
||||||
|
(1, 2, 4, 8)
|
||||||
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||||
self._get_workspace_buffer(),
|
self._get_workspace_buffer(),
|
||||||
"NHD",
|
"NHD",
|
||||||
@ -171,7 +172,8 @@ class FlashInferState(AttentionState):
|
|||||||
self.runner.parallel_config))
|
self.runner.parallel_config))
|
||||||
num_kv_heads = self.runner.model_config.get_num_kv_heads(
|
num_kv_heads = self.runner.model_config.get_num_kv_heads(
|
||||||
self.runner.parallel_config)
|
self.runner.parallel_config)
|
||||||
use_tensor_cores = num_qo_heads // num_kv_heads >= 4
|
use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
|
||||||
|
(1, 2, 4, 8)
|
||||||
self._graph_decode_wrapper = \
|
self._graph_decode_wrapper = \
|
||||||
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
|
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
|
||||||
self._graph_decode_workspace_buffer, _indptr_buffer,
|
self._graph_decode_workspace_buffer, _indptr_buffer,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user