mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-15 02:13:30 +08:00
misc
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
d833982e48
commit
57d404bbb8
@ -475,6 +475,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
max_query_len: int,
|
max_query_len: int,
|
||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
ubatch_id: int = 0
|
||||||
) -> M:
|
) -> M:
|
||||||
num_reqs = req_slice.stop - req_slice.start
|
num_reqs = req_slice.stop - req_slice.start
|
||||||
num_tokens = token_slice.stop - token_slice.start
|
num_tokens = token_slice.stop - token_slice.start
|
||||||
@ -586,6 +587,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
decode_metadata = self._build_decode(
|
decode_metadata = self._build_decode(
|
||||||
block_table_tensor=block_table_tensor[:num_decodes, ...],
|
block_table_tensor=block_table_tensor[:num_decodes, ...],
|
||||||
seq_lens=seq_lens[:num_decodes],
|
seq_lens=seq_lens[:num_decodes],
|
||||||
|
ubatch_id=ubatch_id
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.metadata_cls(
|
return self.metadata_cls(
|
||||||
|
|||||||
@ -63,11 +63,12 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
|||||||
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
|
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
|
||||||
self.runner.parallel_config)
|
self.runner.parallel_config)
|
||||||
|
|
||||||
self.cg_buf_tile_scheduler_metadata = None
|
self.cg_buf_tile_scheduler_metadata = [None, None]
|
||||||
self.cg_buf_num_splits = None
|
self.cg_buf_num_splits = [None, None]
|
||||||
|
|
||||||
def _build_decode(self, block_table_tensor: torch.Tensor,
|
def _build_decode(self, block_table_tensor: torch.Tensor,
|
||||||
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
|
seq_lens: torch.Tensor, ubatch_id = 0) -> FlashMLADecodeMetadata:
|
||||||
|
assert ubatch_id < 2
|
||||||
tile_scheduler_metadata, num_splits = \
|
tile_scheduler_metadata, num_splits = \
|
||||||
get_mla_metadata(
|
get_mla_metadata(
|
||||||
seq_lens,
|
seq_lens,
|
||||||
@ -79,27 +80,27 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
|||||||
if self.runner.full_cuda_graph:
|
if self.runner.full_cuda_graph:
|
||||||
n = num_splits.size(0)
|
n = num_splits.size(0)
|
||||||
# First time around (CUDAGraph capture), allocate the static buffer
|
# First time around (CUDAGraph capture), allocate the static buffer
|
||||||
if self.cg_buf_num_splits is None:
|
if self.cg_buf_num_splits[ubatch_id] is None:
|
||||||
self.cg_buf_num_splits = num_splits
|
self.cg_buf_num_splits[ubatch_id] = num_splits
|
||||||
self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
|
self.cg_buf_tile_scheduler_metadata[ubatch_id] = tile_scheduler_metadata
|
||||||
elif n <= self.cg_buf_num_splits.size(0):
|
elif n <= self.cg_buf_num_splits[ubatch_id].size(0):
|
||||||
assert self.cg_buf_tile_scheduler_metadata is not None
|
assert self.cg_buf_tile_scheduler_metadata[ubatch_id] is not None
|
||||||
|
|
||||||
# Metadata per-SM, fixed size (#SMs, TileMetadataSize)
|
# Metadata per-SM, fixed size (#SMs, TileMetadataSize)
|
||||||
assert (self.cg_buf_tile_scheduler_metadata.size() ==
|
assert (self.cg_buf_tile_scheduler_metadata[ubatch_id].size() ==
|
||||||
tile_scheduler_metadata.size())
|
tile_scheduler_metadata.size())
|
||||||
self.cg_buf_tile_scheduler_metadata.\
|
self.cg_buf_tile_scheduler_metadata[ubatch_id].\
|
||||||
copy_(tile_scheduler_metadata)
|
copy_(tile_scheduler_metadata)
|
||||||
tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
|
tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata[ubatch_id]
|
||||||
|
|
||||||
# Num splits is per-batch, varying size (batch_size,)
|
# Num splits is per-batch, varying size (batch_size,)
|
||||||
n = num_splits.size(0)
|
n = num_splits.size(0)
|
||||||
# logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
|
# logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
|
||||||
# make sure static buffer is large enough
|
# make sure static buffer is large enough
|
||||||
assert n <= self.cg_buf_num_splits.size(0)
|
assert n <= self.cg_buf_num_splits[ubatch_id].size(0)
|
||||||
num_splits_view = self.cg_buf_num_splits[:n]
|
num_splits_view = self.cg_buf_num_splits[ubatch_id][:n]
|
||||||
num_splits_view.copy_(num_splits)
|
num_splits_view.copy_(num_splits)
|
||||||
self.cg_buf_num_splits[n:].fill_(0) # fill the rest with 0s
|
self.cg_buf_num_splits[ubatch_id][n:].fill_(0) # fill the rest with 0s
|
||||||
num_splits = num_splits_view
|
num_splits = num_splits_view
|
||||||
|
|
||||||
return FlashMLADecodeMetadata(
|
return FlashMLADecodeMetadata(
|
||||||
|
|||||||
@ -17,6 +17,7 @@ import vllm.envs as envs
|
|||||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||||
get_kv_connector_cache_layout)
|
get_kv_connector_cache_layout)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.v1.worker.block_table import BlockTable
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -29,6 +30,8 @@ class CommonAttentionMetadata:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
query_start_loc: torch.Tensor
|
query_start_loc: torch.Tensor
|
||||||
|
|
||||||
|
# query_start_loc_cpu: torch.Tensor
|
||||||
"""(batch_size + 1,), the start location of each request in query Tensor"""
|
"""(batch_size + 1,), the start location of each request in query Tensor"""
|
||||||
seq_lens: torch.Tensor
|
seq_lens: torch.Tensor
|
||||||
"""(batch_size,), the length of each request including both computed tokens
|
"""(batch_size,), the length of each request including both computed tokens
|
||||||
@ -41,6 +44,47 @@ class CommonAttentionMetadata:
|
|||||||
max_query_len: int
|
max_query_len: int
|
||||||
"""Longest query in batch"""
|
"""Longest query in batch"""
|
||||||
|
|
||||||
|
# block_table: BlockTable
|
||||||
|
|
||||||
|
# def compute_request_slice(self, token_slice: slice) -> slice:
|
||||||
|
# """
|
||||||
|
# return
|
||||||
|
# - num_decodes: number of decode requests
|
||||||
|
# - num_prefills: number of prefill requests
|
||||||
|
# - num_decode_tokens: number of decode tokens
|
||||||
|
# - num_prefill_tokens: number of prefill tokens
|
||||||
|
# """
|
||||||
|
# if self.max_query_len == 1:
|
||||||
|
# # Pure decode
|
||||||
|
# return token_slice
|
||||||
|
# else:
|
||||||
|
# # Find the first query_start_loc that's greater than the token_slice.start
|
||||||
|
# first_reqest = (self.query_start_loc_cpu >= token_slice.start).int().argmax(dim=-1).item()
|
||||||
|
# last_request = (self.query_start_loc_cpu < token_slice.stop).int().argmax(dim=-1).item()
|
||||||
|
# return slice(first_reqest, last_request)
|
||||||
|
|
||||||
|
# # Slice the current CommonAttentionMetatdata into two
|
||||||
|
# def _slice(self, token_slice: slice) -> CommonAttentionMetadata:
|
||||||
|
# request_slice = self.compute_request_slice(token_slice)
|
||||||
|
# query_start_loc = slice_query_start_locs(
|
||||||
|
# self.query_start_loc, request_slice)
|
||||||
|
|
||||||
|
# seq_lens = self.seq_lens[request_slice]
|
||||||
|
# num_requests = request_slice.stop - request_slice.start
|
||||||
|
# num_actual_tokens = token_slice.stop - token_slice.start
|
||||||
|
# #TODO(Sage) update this for prefill
|
||||||
|
# max_query_len = 1
|
||||||
|
|
||||||
|
# block_table = self.block_table
|
||||||
|
# block_table_tensor = block_table.get_device_tensor()[req_slice]
|
||||||
|
# block_table.slot_mapping[token_slice].copy_(
|
||||||
|
# block_table.slot_mapping_cpu[token_slice],
|
||||||
|
# non_blocking=True)
|
||||||
|
# block_table.slot_mapping[token_slice.stop:].fill_(-1)
|
||||||
|
# slot_mapping = block_table.slot_mapping[token_slice]
|
||||||
|
|
||||||
|
# pass
|
||||||
|
|
||||||
|
|
||||||
M = TypeVar("M")
|
M = TypeVar("M")
|
||||||
|
|
||||||
|
|||||||
@ -839,6 +839,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
max_query_len=max(tokens[req_slice]),
|
max_query_len=max(tokens[req_slice]),
|
||||||
common_prefix_len=common_prefix_len,
|
common_prefix_len=common_prefix_len,
|
||||||
common_attn_metadata=common_attn_metadata,
|
common_attn_metadata=common_attn_metadata,
|
||||||
|
ubatch_id=ubid
|
||||||
))
|
))
|
||||||
for layer_name in kv_cache_group_spec.layer_names:
|
for layer_name in kv_cache_group_spec.layer_names:
|
||||||
assert type(attn_metadata) is list
|
assert type(attn_metadata) is list
|
||||||
@ -1583,7 +1584,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
def _make_ubatch_contexts(ubatch_slices,
|
def _make_ubatch_contexts(ubatch_slices,
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
compute_stream,
|
compute_stream,
|
||||||
is_dummy_run,
|
|
||||||
num_tokens_across_dp,
|
num_tokens_across_dp,
|
||||||
skip_cuda_graphs) -> list[UBatchContext]:
|
skip_cuda_graphs) -> list[UBatchContext]:
|
||||||
ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
|
ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
|
||||||
@ -1623,7 +1623,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
ubatch_slices=ubatch_slices,
|
ubatch_slices=ubatch_slices,
|
||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
compute_stream=compute_stream,
|
compute_stream=compute_stream,
|
||||||
is_dummy_run=is_dummy_run,
|
|
||||||
num_tokens_across_dp=num_tokens_across_dp,
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
skip_cuda_graphs=skip_cuda_graphs
|
skip_cuda_graphs=skip_cuda_graphs
|
||||||
)
|
)
|
||||||
@ -2369,7 +2368,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# _dummy_run doesn't go through _prepare_inputs so
|
# _dummy_run doesn't go through _prepare_inputs so
|
||||||
# we synchronize with other DP ranks here
|
# we synchronize with other DP ranks here
|
||||||
# logger.info(f"NUM TOKENS {num_tokens} SHOULD UBATCH {should_ubatch}")
|
# logger.info(f"NUM TOKENS {num_tokens} SHOULD UBATCH {should_ubatch}")
|
||||||
should_ubatch = self.should_ubatch(allow_microbatching)
|
should_ubatch = self.should_ubatch(should_ubatch)
|
||||||
# Padding for DP
|
# Padding for DP
|
||||||
# logger.info("PADDING DUMMY")
|
# logger.info("PADDING DUMMY")
|
||||||
num_tokens_across_dp = None
|
num_tokens_across_dp = None
|
||||||
@ -2451,6 +2450,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
common_prefix_len=0,
|
common_prefix_len=0,
|
||||||
common_attn_metadata=common_attn_metadata,
|
common_attn_metadata=common_attn_metadata,
|
||||||
|
ubatch_id=ubid
|
||||||
))
|
))
|
||||||
for layer_name in kv_cache_group_spec.layer_names:
|
for layer_name in kv_cache_group_spec.layer_names:
|
||||||
assert type(attn_metadata) is list
|
assert type(attn_metadata) is list
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user