mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-31 23:57:11 +08:00
[Core] Separate out attention metadata building logic from prepare inputs (#26764)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
289eb6c537
commit
636efd10a5
@ -1054,7 +1054,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
def _get_encoder_seq_lens(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
scheduled_encoder_inputs: dict[str, list[int]],
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
num_reqs: int,
|
||||
) -> np.ndarray | None:
|
||||
@ -1064,31 +1064,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Build encoder_seq_lens array mapping request indices to
|
||||
# encoder lengths for inputs scheduled in this batch
|
||||
encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32)
|
||||
for req_id in scheduler_output.scheduled_encoder_inputs:
|
||||
for req_id in scheduled_encoder_inputs:
|
||||
req_index = self.input_batch.req_id_to_index[req_id]
|
||||
encoder_seq_lens[req_index] = self.max_encoder_len
|
||||
|
||||
return encoder_seq_lens
|
||||
|
||||
def _prepare_inputs(
|
||||
self, scheduler_output: "SchedulerOutput"
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
max_num_scheduled_tokens: int,
|
||||
) -> tuple[
|
||||
PerLayerAttnMetadata,
|
||||
torch.Tensor,
|
||||
SpecDecodeMetadata | None,
|
||||
np.ndarray,
|
||||
CommonAttentionMetadata | None,
|
||||
int,
|
||||
UBatchSlices | None,
|
||||
torch.Tensor | None,
|
||||
bool,
|
||||
]:
|
||||
"""
|
||||
:return: tuple[
|
||||
attn_metadata: layer-to-attention_metadata mapping,
|
||||
logits_indices, spec_decode_metadata,
|
||||
num_scheduled_tokens, spec_decode_common_attn_metadata,
|
||||
max_num_scheduled_tokens, use_cascade_attn
|
||||
ubatch_slices, num_tokens_across_dp,
|
||||
]
|
||||
"""
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
@ -1100,12 +1096,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# This way, we can overlap the copy with the following CPU operations.
|
||||
self.input_batch.block_table.commit_block_table(num_reqs)
|
||||
|
||||
# Get the number of scheduled tokens for each request.
|
||||
req_ids = self.input_batch.req_ids
|
||||
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
|
||||
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
||||
max_num_scheduled_tokens = max(tokens)
|
||||
|
||||
# Get request indices.
|
||||
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
|
||||
req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)
|
||||
@ -1232,8 +1222,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Fill unused with 0 for full cuda graph mode.
|
||||
self.seq_lens.np[num_reqs:].fill(0)
|
||||
self.seq_lens.copy_to_gpu()
|
||||
seq_lens = self.seq_lens.gpu[:num_reqs]
|
||||
max_seq_len = self.seq_lens.np[:num_reqs].max().item()
|
||||
|
||||
num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids]
|
||||
num_tokens_np = np.array(num_tokens, dtype=np.int32)
|
||||
@ -1305,11 +1293,46 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
|
||||
self.num_decode_draft_tokens.copy_to_gpu()
|
||||
|
||||
logits_indices_padded = None
|
||||
if self.cache_config.kv_sharing_fast_prefill:
|
||||
logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
|
||||
logits_indices
|
||||
# Hot-Swap lora model
|
||||
if self.lora_config:
|
||||
assert (
|
||||
np.sum(num_sampled_tokens)
|
||||
<= self.vllm_config.scheduler_config.max_num_batched_tokens
|
||||
)
|
||||
self.set_active_loras(
|
||||
self.input_batch, num_scheduled_tokens, num_sampled_tokens
|
||||
)
|
||||
|
||||
return (
|
||||
logits_indices,
|
||||
spec_decode_metadata,
|
||||
ubatch_slices,
|
||||
num_tokens_across_dp,
|
||||
)
|
||||
|
||||
def _build_attention_metadata(
|
||||
self,
|
||||
total_num_scheduled_tokens: int,
|
||||
max_num_scheduled_tokens: int,
|
||||
num_reqs: int,
|
||||
ubatch_slices: UBatchSlices | None = None,
|
||||
logits_indices: torch.Tensor | None = None,
|
||||
use_spec_decode: bool = False,
|
||||
for_cudagraph_capture: bool = False,
|
||||
scheduled_encoder_inputs: dict[str, list[int]] | None = None,
|
||||
cascade_attn_prefix_lens: list[list[int]] | None = None,
|
||||
) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
|
||||
"""
|
||||
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
|
||||
"""
|
||||
logits_indices_padded = None
|
||||
num_logits_indices = 0
|
||||
if logits_indices is not None:
|
||||
num_logits_indices = logits_indices.size(0)
|
||||
if self.cache_config.kv_sharing_fast_prefill:
|
||||
logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
|
||||
logits_indices
|
||||
)
|
||||
|
||||
# update seq_lens of decode reqs under DCP.
|
||||
if self.dcp_world_size > 1:
|
||||
@ -1324,15 +1347,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
attn_metadata: PerLayerAttnMetadata = {}
|
||||
if ubatch_slices is not None:
|
||||
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
|
||||
use_cascade_attn = False
|
||||
|
||||
# Used in the below loop.
|
||||
# Used in the below loop
|
||||
query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
|
||||
query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs + 1]
|
||||
seq_lens = self.seq_lens.gpu[:num_reqs]
|
||||
seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
|
||||
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
|
||||
:num_reqs
|
||||
]
|
||||
dcp_local_seq_lens = (
|
||||
self.dcp_local_seq_lens.gpu[:num_reqs] if self.dcp_world_size > 1 else None
|
||||
)
|
||||
spec_decode_common_attn_metadata = None
|
||||
|
||||
if for_cudagraph_capture:
|
||||
# For some attention backends (e.g. FA) with sliding window models we need
|
||||
# to make sure the backend see a max_seq_len that is larger to the sliding
|
||||
# window size when capturing to make sure the correct kernel is selected.
|
||||
max_seq_len = self.max_model_len
|
||||
else:
|
||||
max_seq_len = self.seq_lens.np[:num_reqs].max().item()
|
||||
|
||||
if use_spec_decode:
|
||||
self.num_accepted_tokens.np[:num_reqs] = (
|
||||
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
|
||||
@ -1342,14 +1378,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
# Prepare the attention metadata for each KV cache group and make layers
|
||||
# in the same group share the same metadata.
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
for kv_cache_gid, kv_cache_group in enumerate(
|
||||
self.kv_cache_config.kv_cache_groups
|
||||
):
|
||||
encoder_seq_lens = self._get_encoder_seq_lens(
|
||||
scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs
|
||||
scheduled_encoder_inputs or {},
|
||||
kv_cache_group.kv_cache_spec,
|
||||
num_reqs,
|
||||
)
|
||||
|
||||
if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec):
|
||||
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
|
||||
# Encoder-only layers do not have KV cache, so we need to
|
||||
# create a dummy block table and slot mapping for them.
|
||||
blk_table_tensor = torch.zeros(
|
||||
@ -1362,18 +1400,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
num_common_prefix_blocks = 0
|
||||
else:
|
||||
blk_table = self.input_batch.block_table[kv_cache_group_id]
|
||||
blk_table = self.input_batch.block_table[kv_cache_gid]
|
||||
blk_table_tensor = blk_table.get_device_tensor(num_reqs)
|
||||
slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens]
|
||||
|
||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda
|
||||
# graph mode.
|
||||
blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1)
|
||||
num_common_prefix_blocks = scheduler_output.num_common_prefix_blocks[
|
||||
kv_cache_group_id
|
||||
]
|
||||
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc,
|
||||
@ -1388,35 +1422,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
block_table_tensor=blk_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
logits_indices_padded=logits_indices_padded,
|
||||
num_logits_indices=logits_indices.size(0),
|
||||
num_logits_indices=num_logits_indices,
|
||||
causal=True,
|
||||
encoder_seq_lens=encoder_seq_lens,
|
||||
dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
|
||||
if self.dcp_world_size > 1
|
||||
else None,
|
||||
dcp_local_seq_lens=dcp_local_seq_lens,
|
||||
)
|
||||
|
||||
if self.speculative_config and spec_decode_common_attn_metadata is None:
|
||||
if isinstance(self.drafter, EagleProposer):
|
||||
if (
|
||||
self.drafter.attn_layer_names[0]
|
||||
in kv_cache_group_spec.layer_names
|
||||
):
|
||||
if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
|
||||
spec_decode_common_attn_metadata = common_attn_metadata
|
||||
else:
|
||||
spec_decode_common_attn_metadata = common_attn_metadata
|
||||
|
||||
for attn_group in self.attn_groups[kv_cache_group_id]:
|
||||
# Prepare for cascade attention if enabled & beneficial.
|
||||
common_prefix_len = 0
|
||||
for attn_gid, attn_group in enumerate(self.attn_groups[kv_cache_gid]):
|
||||
cascade_attn_prefix_len = (
|
||||
cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
|
||||
if cascade_attn_prefix_lens
|
||||
else 0
|
||||
)
|
||||
builder = attn_group.get_metadata_builder()
|
||||
if self.cascade_attn_enabled:
|
||||
common_prefix_len = self._compute_cascade_attn_prefix_len(
|
||||
num_scheduled_tokens,
|
||||
num_common_prefix_blocks,
|
||||
attn_group.kv_cache_spec,
|
||||
builder,
|
||||
)
|
||||
|
||||
extra_attn_metadata_args = {}
|
||||
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
|
||||
@ -1434,51 +1459,69 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
for ubid, common_attn_metadata in enumerate(
|
||||
common_attn_metadata_list
|
||||
):
|
||||
attn_metadata_i = attn_group.get_metadata_builder(
|
||||
ubatch_id=ubid
|
||||
).build(
|
||||
common_prefix_len=common_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
builder = attn_group.get_metadata_builder(ubatch_id=ubid)
|
||||
if for_cudagraph_capture:
|
||||
attn_metadata_i = builder.build_for_cudagraph_capture(
|
||||
common_attn_metadata
|
||||
)
|
||||
else:
|
||||
attn_metadata_i = builder.build(
|
||||
common_prefix_len=cascade_attn_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
for layer_name in kv_cache_group.layer_names:
|
||||
assert type(attn_metadata) is list
|
||||
attn_metadata[ubid][layer_name] = attn_metadata_i
|
||||
else:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata_i = builder.build(
|
||||
common_prefix_len=common_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
**extra_attn_metadata_args,
|
||||
)
|
||||
use_cascade_attn |= getattr(attn_metadata_i, "use_cascade", False)
|
||||
if for_cudagraph_capture:
|
||||
attn_metadata_i = builder.build_for_cudagraph_capture(
|
||||
common_attn_metadata
|
||||
)
|
||||
else:
|
||||
attn_metadata_i = builder.build(
|
||||
common_prefix_len=cascade_attn_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
**extra_attn_metadata_args,
|
||||
)
|
||||
for layer_name in attn_group.layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
|
||||
# disable cascade attention when DBO
|
||||
if ubatch_slices is not None:
|
||||
use_cascade_attn = False
|
||||
return attn_metadata, spec_decode_common_attn_metadata
|
||||
|
||||
# Hot-Swap lora model
|
||||
if self.lora_config:
|
||||
assert (
|
||||
np.sum(num_sampled_tokens)
|
||||
<= self.vllm_config.scheduler_config.max_num_batched_tokens
|
||||
)
|
||||
self.set_active_loras(
|
||||
self.input_batch, num_scheduled_tokens, num_sampled_tokens
|
||||
)
|
||||
def _compute_cascade_attn_prefix_lens(
|
||||
self,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
num_common_prefix_blocks: list[int],
|
||||
) -> list[list[int]] | None:
|
||||
"""
|
||||
:return: Optional[cascade_attn_prefix_lens]
|
||||
cascade_attn_prefix_lens is 2D: ``[kv_cache_group_id][attn_group_idx]``,
|
||||
None if we should not use cascade attention
|
||||
"""
|
||||
|
||||
return (
|
||||
attn_metadata,
|
||||
logits_indices,
|
||||
spec_decode_metadata,
|
||||
num_scheduled_tokens,
|
||||
spec_decode_common_attn_metadata,
|
||||
max_num_scheduled_tokens,
|
||||
ubatch_slices,
|
||||
num_tokens_across_dp,
|
||||
use_cascade_attn,
|
||||
)
|
||||
use_cascade_attn = False
|
||||
num_kv_cache_groups = len(self.kv_cache_config.kv_cache_groups)
|
||||
cascade_attn_prefix_lens: list[list[int]] = [
|
||||
[] for _ in range(num_kv_cache_groups)
|
||||
]
|
||||
|
||||
for kv_cache_gid in range(num_kv_cache_groups):
|
||||
for attn_group in self.attn_groups[kv_cache_gid]:
|
||||
if isinstance(attn_group.kv_cache_spec, EncoderOnlyAttentionSpec):
|
||||
cascade_attn_prefix_len = 0
|
||||
else:
|
||||
# 0 if cascade attention should not be used
|
||||
cascade_attn_prefix_len = self._compute_cascade_attn_prefix_len(
|
||||
num_scheduled_tokens,
|
||||
num_common_prefix_blocks[kv_cache_gid],
|
||||
attn_group.kv_cache_spec,
|
||||
attn_group.get_metadata_builder(),
|
||||
)
|
||||
cascade_attn_prefix_lens[kv_cache_gid].append(cascade_attn_prefix_len)
|
||||
use_cascade_attn |= cascade_attn_prefix_len > 0
|
||||
|
||||
return cascade_attn_prefix_lens if use_cascade_attn else None
|
||||
|
||||
def _compute_cascade_attn_prefix_len(
|
||||
self,
|
||||
@ -1504,6 +1547,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
Returns:
|
||||
int: Length of common prefix in tokens.
|
||||
"""
|
||||
|
||||
common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size
|
||||
if common_prefix_len == 0:
|
||||
# Common case.
|
||||
@ -2497,18 +2541,48 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
"it when the requests need prompt logprobs"
|
||||
)
|
||||
|
||||
# Prepare the decoder inputs.
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
req_ids = self.input_batch.req_ids
|
||||
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
|
||||
num_scheduled_tokens_np = np.array(tokens, dtype=np.int32)
|
||||
max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
|
||||
|
||||
(
|
||||
attn_metadata,
|
||||
logits_indices,
|
||||
spec_decode_metadata,
|
||||
num_scheduled_tokens_np,
|
||||
spec_decode_common_attn_metadata,
|
||||
max_query_len,
|
||||
ubatch_slices,
|
||||
num_tokens_across_dp,
|
||||
use_cascade_attn,
|
||||
) = self._prepare_inputs(scheduler_output)
|
||||
) = self._prepare_inputs(
|
||||
scheduler_output, num_scheduled_tokens_np, max_num_scheduled_tokens
|
||||
)
|
||||
|
||||
cascade_attn_prefix_lens = None
|
||||
# Disable cascade attention when using microbatching (DBO)
|
||||
if self.cascade_attn_enabled and ubatch_slices is None:
|
||||
# Pre-compute cascade attention prefix lengths
|
||||
# NOTE: Must be AFTER _prepare_inputs uses self.input_batch state
|
||||
cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
|
||||
num_scheduled_tokens_np,
|
||||
scheduler_output.num_common_prefix_blocks,
|
||||
)
|
||||
|
||||
# TODO(lucas): move cudagraph dispatching here:
|
||||
# https://github.com/vllm-project/vllm/issues/23789
|
||||
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||
attn_metadata, spec_decode_common_attn_metadata = (
|
||||
self._build_attention_metadata(
|
||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||
max_num_scheduled_tokens=max_num_scheduled_tokens,
|
||||
num_reqs=num_reqs,
|
||||
ubatch_slices=ubatch_slices,
|
||||
logits_indices=logits_indices,
|
||||
use_spec_decode=use_spec_decode,
|
||||
scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs,
|
||||
cascade_attn_prefix_lens=cascade_attn_prefix_lens,
|
||||
)
|
||||
)
|
||||
|
||||
dp_rank = self.parallel_config.data_parallel_rank
|
||||
if ubatch_slices:
|
||||
@ -2532,16 +2606,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
scheduler_output, num_input_tokens, intermediate_tensors
|
||||
)
|
||||
|
||||
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
|
||||
num_scheduled_tokens == self.input_batch.num_reqs * max_query_len
|
||||
)
|
||||
uniform_decode = (
|
||||
max_num_scheduled_tokens == self.uniform_decode_query_len
|
||||
) and (num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
|
||||
batch_descriptor = BatchDescriptor(
|
||||
num_tokens=num_input_tokens,
|
||||
uniform_decode=uniform_decode,
|
||||
has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
|
||||
)
|
||||
cudagraph_runtime_mode, batch_descriptor = (
|
||||
self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn)
|
||||
self.cudagraph_dispatcher.dispatch(
|
||||
batch_descriptor,
|
||||
use_cascade_attn=cascade_attn_prefix_lens is not None,
|
||||
)
|
||||
)
|
||||
|
||||
# Set cudagraph mode to none if calc_kv_scales is true.
|
||||
@ -3437,10 +3514,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# If force_attention is True, we always capture attention. Otherwise,
|
||||
# it only happens for cudagraph_runtime_mode=FULL.
|
||||
if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
attn_metadata = {}
|
||||
if ubatch_slices is not None:
|
||||
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
|
||||
|
||||
if create_mixed_batch:
|
||||
# In the mixed batch mode (used for FI warmup), we use
|
||||
# shorter sequence lengths to run faster.
|
||||
@ -3456,55 +3529,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
|
||||
self.query_start_loc.copy_to_gpu()
|
||||
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
self.kv_cache_config.kv_cache_groups
|
||||
):
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=self.query_start_loc.gpu[: num_reqs + 1],
|
||||
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs + 1],
|
||||
seq_lens=self.seq_lens.gpu[:num_reqs],
|
||||
seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
|
||||
num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[
|
||||
:num_reqs
|
||||
],
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=self.max_model_len,
|
||||
block_table_tensor=self.input_batch.block_table[
|
||||
kv_cache_group_id
|
||||
].get_device_tensor(num_reqs),
|
||||
slot_mapping=self.input_batch.block_table[
|
||||
kv_cache_group_id
|
||||
].slot_mapping.gpu[:num_tokens],
|
||||
causal=True,
|
||||
dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
|
||||
if self.dcp_world_size > 1
|
||||
else None,
|
||||
)
|
||||
for attn_group in self.attn_groups[kv_cache_group_id]:
|
||||
if ubatch_slices is not None:
|
||||
common_attn_metadata_list = split_attn_metadata(
|
||||
ubatch_slices, common_attn_metadata
|
||||
)
|
||||
for ubid, common_attn_metadata in enumerate(
|
||||
common_attn_metadata_list
|
||||
):
|
||||
assert common_attn_metadata.max_query_len == 1
|
||||
attn_metadata_i = attn_group.get_metadata_builder(
|
||||
ubatch_id=ubid
|
||||
).build_for_cudagraph_capture(common_attn_metadata)
|
||||
for layer_name in attn_group.layer_names:
|
||||
assert type(attn_metadata) is list
|
||||
attn_metadata[ubid][layer_name] = attn_metadata_i
|
||||
else:
|
||||
assert type(attn_metadata) is dict
|
||||
metadata_builder = attn_group.get_metadata_builder()
|
||||
attn_metadata_i = metadata_builder.build_for_cudagraph_capture(
|
||||
common_attn_metadata
|
||||
)
|
||||
for layer_name in attn_group.layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
attn_metadata, _ = self._build_attention_metadata(
|
||||
total_num_scheduled_tokens=num_tokens,
|
||||
max_num_scheduled_tokens=max_query_len,
|
||||
num_reqs=num_reqs,
|
||||
ubatch_slices=ubatch_slices,
|
||||
for_cudagraph_capture=True,
|
||||
)
|
||||
|
||||
with self.maybe_dummy_run_with_lora(
|
||||
self.lora_config,
|
||||
@ -4478,9 +4509,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
list[int]: List of kernel block sizes for each cache group.
|
||||
"""
|
||||
kernel_block_sizes = []
|
||||
for kv_cache_group_id, kv_cache_group in enumerate(
|
||||
kv_cache_config.kv_cache_groups
|
||||
):
|
||||
for kv_cache_gid, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
|
||||
kv_cache_spec = kv_cache_group.kv_cache_spec
|
||||
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
|
||||
# All layers in the UniformTypeKVCacheSpecs have the same type,
|
||||
@ -4492,7 +4521,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# This is an attention backend that supports virtual
|
||||
# block splitting. Get the supported block sizes from
|
||||
# all backends in the group.
|
||||
attn_groups = self.attn_groups[kv_cache_group_id]
|
||||
attn_groups = self.attn_groups[kv_cache_gid]
|
||||
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
|
||||
selected_kernel_size = self.select_common_block_size(
|
||||
kv_manager_block_size, attn_groups
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user