mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-27 06:47:02 +08:00
mla format
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
44a595f6d6
commit
d4b502a73a
@ -206,12 +206,11 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv, round_down
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
|
||||
slice_query_start_locs)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
from vllm.v1.attention.backends.utils import slice_query_start_locs
|
||||
|
||||
try:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
is_vllm_fa = True
|
||||
@ -442,8 +441,10 @@ class MLACommonMetadataBuilder(Generic[M]):
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens,
|
||||
)
|
||||
|
||||
def _split_decodes_and_prefills(self, max_query_len: int, num_reqs: int, num_tokens: int, query_start_loc: torch.Tensor):
|
||||
|
||||
def _split_decodes_and_prefills(self, max_query_len: int, num_reqs: int,
|
||||
num_tokens: int,
|
||||
query_start_loc: torch.Tensor):
|
||||
"""
|
||||
return
|
||||
- num_decodes: number of decode requests
|
||||
@ -462,16 +463,17 @@ class MLACommonMetadataBuilder(Generic[M]):
|
||||
num_prefills = num_reqs - num_decodes
|
||||
num_decode_tokens = first_prefill
|
||||
num_prefill_tokens = num_tokens - query_start_loc[first_prefill]
|
||||
return (
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens
|
||||
)
|
||||
|
||||
def build_slice(self, req_slice: slice,
|
||||
token_slice: slice,
|
||||
max_query_len: int,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> M:
|
||||
return (num_decodes, num_prefills, num_decode_tokens,
|
||||
num_prefill_tokens)
|
||||
|
||||
def build_slice(
|
||||
self,
|
||||
req_slice: slice,
|
||||
token_slice: slice,
|
||||
max_query_len: int,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> M:
|
||||
num_reqs = req_slice.stop - req_slice.start
|
||||
num_tokens = token_slice.stop - token_slice.start
|
||||
|
||||
@ -481,14 +483,13 @@ class MLACommonMetadataBuilder(Generic[M]):
|
||||
device = self.runner.device
|
||||
block_table = self.block_table
|
||||
block_table_tensor = block_table.get_device_tensor()[req_slice]
|
||||
# print(f"num_reqs: {num_reqs} bloc_table_shape: {block_table_tensor.shape}")
|
||||
slot_mapping = block_table.slot_mapping_cpu[token_slice].to(
|
||||
device, non_blocking=True).long()
|
||||
|
||||
query_start_loc = slice_query_start_locs(
|
||||
common_attn_metadata.query_start_loc, req_slice)
|
||||
seq_lens = common_attn_metadata.seq_lens[req_slice]
|
||||
|
||||
|
||||
num_computed_tokens = self.runner.input_batch.\
|
||||
num_computed_tokens_cpu_tensor[req_slice]
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user