[Attention] Optimize FlashInfer MetadataBuilder Build call (#21137)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 526078a96c
commit 61b8cea3b4
@@ -11,7 +11,8 @@ from tests.v1.attention.utils import (BatchSpec, _Backend,
                                       create_vllm_config,
                                       get_attention_backend)
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
+                                              set_kv_cache_layout)
 from vllm.v1.kv_cache_interface import FullAttentionSpec
 
 BACKENDS_TO_TEST = [
@@ -212,7 +213,7 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
 
         from vllm.v1.attention.backends.flashinfer import PerLayerParameters
 
-        def mock_get_per_layer_parameters(vllm_config):
+        def mock_get_per_layer_parameters(vllm_config, impl_cls):
             # Return mock parameters for a single layer
             head_size = vllm_config.model_config.get_head_size()
             return {
@@ -297,7 +298,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
     5. Comparing the vLLM backend's output to the ground-truth SDPA output.
     """
     batch_spec = BATCH_SPECS[batch_spec_name]
-    vllm_config = create_vllm_config(model_name=model)
+    vllm_config = create_vllm_config(model_name=model,
+                                     max_model_len=max(batch_spec.seq_lens))
     device = torch.device("cuda:0")
 
     kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
@@ -419,6 +421,11 @@ def test_backend_correctness(batch_spec_name: str, model: str):
         if backend_name == _Backend.FLASHINFER_VLLM_V1:
             kv_cache_for_backend = kv_cache.transpose(0, 1)
 
+            # For FlashInfer default to HND layout and
+            kv_cache_for_backend = kv_cache_for_backend.transpose(
+                2, 3).contiguous().transpose(2, 3)
+            set_kv_cache_layout("HND")
+
         backend_output = run_attention_backend(backend_name, kv_cache_spec,
                                                vllm_config, device,
                                                common_attn_metadata,
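A minimal standalone sketch (illustrative shapes, not from the patch) of what the transpose().contiguous().transpose() idiom above does: it rewrites the underlying storage into a heads-major ("HND") order while keeping the logical shape and values unchanged, so only the memory layout handed to FlashInfer changes.

import torch

# Hypothetical shape: (num_blocks, 2, block_size, num_kv_heads, head_size)
x = torch.randn(4, 2, 16, 8, 64)
y = x.transpose(2, 3).contiguous().transpose(2, 3)
assert y.shape == x.shape          # logical shape is unchanged
assert torch.equal(y, x)           # values are unchanged
assert y.stride() != x.stride()    # but storage is now heads-major ("HND")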
@@ -66,7 +66,7 @@ def create_common_attn_metadata(
     num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
 
     # Create block table (random for testing)
-    max_blocks = max(batch_spec.seq_lens) // block_size + 1
+    max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size
     block_table_tensor = torch.randint(0,
                                        max_block_idx,
                                        (batch_spec.batch_size, max_blocks),
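A quick standalone check of the max_blocks change above (made-up values): the old expression over-allocates one block whenever the longest sequence is an exact multiple of block_size, while ceiling division does not.

block_size = 16
for seq_len in (15, 16, 17, 32):
    old = seq_len // block_size + 1                 # 1, 2, 2, 3
    new = (seq_len + block_size - 1) // block_size  # 1, 1, 2, 2
    print(seq_len, old, new)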
@@ -18,6 +18,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils import cdiv
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
@@ -158,7 +159,7 @@ class FlashInferMetadata:
     # (batch_size + 1,). The cumulative subquery lengths of the sequences in
     # the batch, used to index into subquery. E.g., if the subquery length
     # is [4, 6], it is [0, 4, 10].
-    qo_indptr: torch.Tensor
+    qo_indptr_cpu: torch.Tensor
     # An example for paged_kv_indices, paged_kv_indptr:
     # request 1, page indices [0, 5, 8]
     # request 2, page indices [1, 6, 7]
@@ -167,13 +168,13 @@ class FlashInferMetadata:
     # [0, 5, 8, 1, 6, 7, 3, 4]
     # paged_kv_indptr is used to index into paged_kv_indices:
     # [0, 3, 6, 8]
-    # The indptr of the paged kv cache, shape: [batch_size + 1]
-    paged_kv_indptr: torch.Tensor
-    # The page indices of the paged kv cache
+    # The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan)
+    paged_kv_indptr_cpu: torch.Tensor
+    # The page indices of the paged kv cache (on device for plan)
     paged_kv_indices: torch.Tensor
     # The number of entries in the last page of each request in
-    # the paged kv cache, shape: [batch_size]
-    paged_kv_last_page_len: torch.Tensor
+    # the paged kv cache, shape: [batch_size] (CPU for plan)
+    paged_kv_last_page_len_cpu: torch.Tensor
     # The number of query/output heads
     num_qo_heads: int
     # The number of key/value heads
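A small standalone example (assumed page counts, mirroring the comment above) of how an indptr like [0, 3, 6, 8] is just an exclusive prefix sum over per-request page counts; the builder below assembles paged_kv_indptr_cpu the same way.

import torch

page_counts = torch.tensor([3, 3, 2], dtype=torch.int32)    # pages per request
paged_kv_indptr = torch.zeros(len(page_counts) + 1, dtype=torch.int32)
paged_kv_indptr[1:] = page_counts.cumsum(dim=0, dtype=torch.int32)
print(paged_kv_indptr.tolist())                              # [0, 3, 6, 8]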
@@ -201,22 +202,17 @@ class FlashInferMetadata:
     num_prefills: int
     num_prefill_tokens: int
 
-    # For cascade attention.
+    # For cascade attention (CPU for planning).
     use_cascade: bool
-    shared_qo_indptr: Optional[torch.Tensor] = None
-    shared_kv_page_indptr: Optional[torch.Tensor] = None
-    shared_kv_page_indices: Optional[torch.Tensor] = None
-    shared_kv_last_page_len: Optional[torch.Tensor] = None
+    shared_qo_indptr_cpu: Optional[torch.Tensor] = None
+    shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None
+    shared_kv_page_indices_cpu: Optional[torch.Tensor] = None
+    shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None
 
     prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
     cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
 
-    @property
-    def query_start_loc(self):
-        # The GPUModelRunner expects to be able to access this property.
-        return self.qo_indptr
-
     def __post_init__(self):
         if self.head_dim is not None:
             FlashInferBackend.validate_head_size(self.head_dim)
@@ -238,6 +234,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self.vllm_config = vllm_config
         self.cache_config = vllm_config.cache_config
         self.kv_cache_spec = kv_cache_spec
+        max_num_blocks_per_request = cdiv(
+            vllm_config.model_config.max_model_len,
+            self.kv_cache_spec.block_size)
+        self.block_table_arange = torch.arange(max_num_blocks_per_request,
+                                               dtype=torch.int32,
+                                               device=self.device)
 
     def reorder_batch(self, input_batch: InputBatch,
                       scheduler_output: SchedulerOutput) -> bool:
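A minimal sketch (hypothetical sizes and block table, variable names mirroring the diff) of why the arange is precomputed once in __init__: each build call can then slice it and compare against per-request block counts to gather the valid page indices, instead of allocating a fresh torch.arange on every call.

import torch

block_table_arange = torch.arange(8, dtype=torch.int32)        # built once
block_table_bounds = torch.tensor([3, 5], dtype=torch.int32)    # blocks in use per request
max_num_blocks = int(block_table_bounds.max())
mask = (block_table_arange[:max_num_blocks].unsqueeze(0)
        < block_table_bounds.unsqueeze(1))                      # (num_reqs, max_num_blocks)
block_table = torch.tensor([[7, 2, 9, 0, 0, 0, 0, 0],
                            [4, 1, 8, 6, 3, 0, 0, 0]],
                           dtype=torch.int32)
paged_kv_indices = block_table[:, :max_num_blocks][mask]
print(paged_kv_indices.tolist())                                # [7, 2, 9, 4, 1, 8, 6, 3]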
@@ -285,21 +287,25 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         if self.global_hyperparameters is None:
             self.global_hyperparameters = infer_global_hyperparameters(
                 get_per_layer_parameters(self.vllm_config, FlashInferImpl))
+
         if attn_metadata.use_cascade:
             attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
             attn_metadata.cascade_wrapper.plan(
-                [attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr],
                 [
-                    attn_metadata.shared_kv_page_indptr,
-                    attn_metadata.paged_kv_indptr
+                    attn_metadata.shared_qo_indptr_cpu,
+                    attn_metadata.qo_indptr_cpu
                 ],
                 [
-                    attn_metadata.shared_kv_page_indices,
+                    attn_metadata.shared_kv_page_indptr_cpu,
+                    attn_metadata.paged_kv_indptr_cpu
+                ],
+                [
+                    attn_metadata.shared_kv_page_indices_cpu,
                     attn_metadata.paged_kv_indices
                 ],
                 [
-                    attn_metadata.shared_kv_last_page_len,
-                    attn_metadata.paged_kv_last_page_len
+                    attn_metadata.shared_kv_last_page_len_cpu,
+                    attn_metadata.paged_kv_last_page_len_cpu
                 ],
                 attn_metadata.num_qo_heads,
                 attn_metadata.num_kv_heads,
@@ -320,22 +326,22 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             # Decodes are first so prefills start after the last decode
             prefill_start = num_decodes
             attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
-            assert attn_metadata.qo_indptr[prefill_start:].shape[
+            assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[
                 0] == num_prefills + 1
-            assert attn_metadata.paged_kv_indptr[prefill_start:].shape[
+            assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[
                 0] == num_prefills + 1
-            assert attn_metadata.paged_kv_last_page_len[
+            assert attn_metadata.paged_kv_last_page_len_cpu[
                 prefill_start:].shape[0] == num_prefills
             # Since prefill_wrapper.run() will be called with
             # query[num_decode_tokens:] we need to adjust the qo_indptr
             # to be relative to the start of the prefill queries.
-            qo_indptr = attn_metadata.qo_indptr[
-                prefill_start:] - attn_metadata.qo_indptr[prefill_start]
+            qo_indptr_cpu = attn_metadata.qo_indptr_cpu[
+                prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start]
             attn_metadata.prefill_wrapper.plan(
-                qo_indptr,
-                attn_metadata.paged_kv_indptr[prefill_start:],
+                qo_indptr_cpu,
+                attn_metadata.paged_kv_indptr_cpu[prefill_start:],
                 attn_metadata.paged_kv_indices,
-                attn_metadata.paged_kv_last_page_len[prefill_start:],
+                attn_metadata.paged_kv_last_page_len_cpu[prefill_start:],
                 attn_metadata.num_qo_heads,
                 attn_metadata.num_kv_heads,
                 attn_metadata.head_dim,
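A worked example (assumed token counts) of the qo_indptr adjustment above: with two decode requests of one token each followed by prefills of 3 and 4 tokens, the prefill slice is rebased so it indexes into query[num_decode_tokens:].

import torch

qo_indptr_cpu = torch.tensor([0, 1, 2, 5, 9], dtype=torch.int32)
prefill_start = num_decodes = 2
rebased = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
print(rebased.tolist())    # [0, 3, 7]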
@@ -357,9 +363,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
                     attn_metadata.head_dim):
                 attn_metadata.decode_wrapper.plan(
-                    attn_metadata.paged_kv_indptr[:num_decodes + 1],
+                    attn_metadata.paged_kv_indptr_cpu[:num_decodes + 1],
                     attn_metadata.paged_kv_indices,
-                    attn_metadata.paged_kv_last_page_len[:num_decodes],
+                    attn_metadata.paged_kv_last_page_len_cpu[:num_decodes],
                     attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads,
                     attn_metadata.head_dim,
@@ -383,55 +389,58 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             split_decodes_and_prefills(common_attn_metadata)
 
         page_size = self.kv_cache_spec.block_size
-        device = self.device
-        qo_indptr = common_attn_metadata.query_start_loc
         max_seq_len = common_attn_metadata.seq_lens_cpu.max()
         seq_lens = common_attn_metadata.seq_lens
+        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor
 
-        block_table_bounds = (seq_lens + page_size - 1) // page_size
+        block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size
 
         use_cascade = common_prefix_len > 0
         if use_cascade:
             # Grab the blocks of the shared prefix from the first request.
             assert common_prefix_len % page_size == 0
             num_common_kv_blocks = common_prefix_len // page_size
-            shared_qo_indptr = torch.tensor([0, num_actual_tokens],
-                                            dtype=torch.int32,
-                                            device=device)
-            shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
-                                                 dtype=torch.int32,
-                                                 device=device)
-            shared_kv_page_indices = block_table_tensor[
-                0, :num_common_kv_blocks]
-            shared_kv_last_page_len = torch.tensor([page_size],
-                                                   dtype=torch.int32,
-                                                   device=device)
+
+            # Create CPU versions directly for cascade (no GPU versions needed)
+            shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens],
+                                                dtype=torch.int32,
+                                                device='cpu')
+            shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks],
+                                                     dtype=torch.int32,
+                                                     device='cpu')
+            shared_kv_page_indices_cpu = block_table_tensor[
+                0, :num_common_kv_blocks]
+            shared_kv_last_page_len_cpu = torch.tensor([page_size],
+                                                       dtype=torch.int32,
+                                                       device='cpu')
+
             # Remove the blocks of the shared prefix from all requests.
             block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
-            block_table_bounds -= num_common_kv_blocks
+            block_table_bounds_cpu -= num_common_kv_blocks
         else:
-            shared_qo_indptr = None
-            shared_kv_page_indptr = None
-            shared_kv_page_indices = None
-            shared_kv_last_page_len = None
+            shared_qo_indptr_cpu = None
+            shared_kv_page_indptr_cpu = None
+            shared_kv_page_indices_cpu = None
+            shared_kv_last_page_len_cpu = None
 
-        mask = (torch.arange(block_table_tensor.size(1),
-                             dtype=block_table_tensor.dtype,
-                             device=block_table_tensor.device).unsqueeze(0)
+        max_num_blocks = block_table_bounds_cpu.max()
+        block_table_bounds = block_table_bounds_cpu.to(self.device,
+                                                       non_blocking=True)
+        mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
                 < block_table_bounds.unsqueeze(1))
-        paged_kv_indices = block_table_tensor[mask]
+        paged_kv_indices = block_table_tensor[:, :max_num_blocks][mask]
 
-        paged_kv_indptr = torch.cat([
-            torch.zeros(1,
-                        dtype=block_table_bounds.dtype,
-                        device=block_table_bounds.device),
-            block_table_bounds.cumsum(dim=0, dtype=torch.int32)
-        ])
+        paged_kv_indptr_cpu = torch.zeros(len(block_table_bounds_cpu) + 1,
+                                          dtype=torch.int32,
+                                          device='cpu')
+        paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum(
+            dim=0, dtype=torch.int32)
 
-        paged_kv_last_page_len = seq_lens % page_size
-        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
-                                             page_size, paged_kv_last_page_len)
+        paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
+        paged_kv_last_page_len_cpu = torch.where(
+            paged_kv_last_page_len_cpu == 0, page_size,
+            paged_kv_last_page_len_cpu)
         cache_dtype = self.cache_config.cache_dtype
         if cache_dtype.startswith("fp8"):
             kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
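A standalone example (assumed lengths) of the last-page computation near the end of the hunk above: seq_len % page_size gives the fill of the final page, and torch.where maps exact multiples from 0 back to a full page.

import torch

page_size = 16
seq_lens_cpu = torch.tensor([15, 16, 33], dtype=torch.int32)
last_page_len = seq_lens_cpu % page_size                           # [15, 0, 1]
last_page_len = torch.where(last_page_len == 0, page_size, last_page_len)
print(last_page_len.tolist())                                      # [15, 16, 1]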
@@ -440,10 +449,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             kv_cache_dtype = self.kv_cache_spec.dtype
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
-            qo_indptr=qo_indptr,
-            paged_kv_indptr=paged_kv_indptr,
+            qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
+            paged_kv_indptr_cpu=paged_kv_indptr_cpu,
             paged_kv_indices=paged_kv_indices,
-            paged_kv_last_page_len=paged_kv_last_page_len,
+            paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
             num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
                 self.vllm_config.parallel_config),
             num_kv_heads=self.kv_cache_spec.num_kv_heads,
@@ -457,14 +466,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             use_cascade=use_cascade,
-            shared_qo_indptr=shared_qo_indptr,
-            shared_kv_page_indptr=shared_kv_page_indptr,
-            shared_kv_page_indices=shared_kv_page_indices,
-            shared_kv_last_page_len=shared_kv_last_page_len,
+            shared_qo_indptr_cpu=shared_qo_indptr_cpu,
+            shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu,
+            shared_kv_page_indices_cpu=shared_kv_page_indices_cpu,
+            shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table_tensor=block_table_tensor,
-            workspace_buffer=self._workspace_buffer,
+            workspace_buffer=self._get_workspace_buffer(),
         )
 
         self._plan(num_prefills, num_decodes, attn_metadata)