mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 20:57:08 +08:00

wip

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>

This commit is contained in:
parent 980a172474
commit 8293182c8c
@@ -31,6 +31,8 @@ if current_platform.is_cuda():
     from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                       get_scheduler_metadata)

+from vllm.v1.attention.backends.utils import slice_query_start_locs
+
 logger = init_logger(__name__)

@@ -322,23 +324,23 @@ class FlashAttentionMetadataBuilder:
                        scheduler_output: "SchedulerOutput") -> bool:
         return False

-    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata):
-        max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
-        query_start_loc = common_attn_metadata.query_start_loc
-        seq_lens = common_attn_metadata.seq_lens
+    def build_slice(self, max_query_len: int, common_prefix_len: int,
+                    common_attn_metadata: CommonAttentionMetadata,
+                    req_slice: slice,
+                    token_slice: slice) -> FlashAttentionMetadata:
+        num_reqs = req_slice.stop - req_slice.start
+        num_tokens = token_slice.stop - token_slice.start
+
+        max_seq_len = self.runner.seq_lens_np[req_slice].max()
+        query_start_loc = slice_query_start_locs(
+            common_attn_metadata.query_start_loc, req_slice)
+        seq_lens = common_attn_metadata.seq_lens[req_slice]
         block_table = self.block_table
-        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
+        block_table_tensor = block_table.get_device_tensor()[req_slice]

-        block_table.slot_mapping[:num_actual_tokens].copy_(
-            block_table.slot_mapping_cpu[:num_actual_tokens],
-            non_blocking=True)
-        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
-        # mode.
-        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
-
-        slot_mapping = block_table.slot_mapping[:num_actual_tokens]
+        block_table.slot_mapping[token_slice].copy_(
+            block_table.slot_mapping_cpu[token_slice], non_blocking=True)
+        slot_mapping = block_table.slot_mapping[token_slice]

         if self.aot_sliding_window is None:
             self.aot_sliding_window = (-1, -1)
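Taken together, req_slice and token_slice describe one contiguous chunk of the batch: the first indexes requests, the second indexes the flattened token stream, and the token endpoints must coincide with request boundaries in query_start_loc. A minimal sketch of that invariant with made-up numbers (the two-way split below is illustrative, not from this commit):

    query_start_loc = [0, 3, 7, 9, 14]  # 4 requests, 14 tokens total
    ubatch_slices = [
        (slice(0, 2), slice(0, 7)),    # requests 0-1 own tokens [0, 7)
        (slice(2, 4), slice(7, 14)),   # requests 2-3 own tokens [7, 14)
    ]
    for req_slice, token_slice in ubatch_slices:
        # token bounds fall exactly on request boundaries
        assert token_slice.start == query_start_loc[req_slice.start]
        assert token_slice.stop == query_start_loc[req_slice.stop]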
@@ -380,8 +382,8 @@ class FlashAttentionMetadataBuilder:
             seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                 virt_block_table_tensor = make_local_attention_virtual_batches(
                 self.runner.attention_chunk_size,
-                self.runner.query_start_loc_np[:num_reqs + 1],
-                self.runner.seq_lens_np[:num_reqs],
+                query_start_loc,
+                seq_lens,
                 block_table_tensor,
                 self.block_size,
             )
@@ -411,20 +413,20 @@ class FlashAttentionMetadataBuilder:
         use_cascade = common_prefix_len > 0

         if use_cascade:
-            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
+            cu_prefix_query_lens = torch.tensor([0, num_tokens],
                                                 dtype=torch.int32,
                                                 device=self.runner.device)
             prefix_kv_lens = torch.tensor([common_prefix_len],
                                           dtype=torch.int32,
                                           device=self.runner.device)
-            suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
+            suffix_kv_lens = (self.runner.seq_lens_np[req_slice] -
                               common_prefix_len)
             suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
                 self.runner.device)
             prefix_scheduler_metadata = schedule(
                 batch_size=1,
                 cu_query_lens=cu_prefix_query_lens,
-                max_query_len=num_actual_tokens,
+                max_query_len=num_tokens,
                 seqlens=prefix_kv_lens,
                 max_seq_len=common_prefix_len,
                 causal=False)
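In the cascade path the shared prefix is attended to as a single pseudo-request covering every scheduled token, which is why num_actual_tokens becomes the sliced num_tokens here. A worked example with illustrative numbers (not taken from the commit):

    import numpy as np

    common_prefix_len = 16
    seq_lens = np.array([20, 18, 17])   # full per-request lengths
    num_tokens = 5                      # scheduled tokens in this slice

    cu_prefix_query_lens = [0, num_tokens]          # one query batch of 5 tokens
    prefix_kv_lens = [common_prefix_len]            # shared 16-token prefix KV
    suffix_kv_lens = seq_lens - common_prefix_len   # per-request KV past prefix
    print(suffix_kv_lens)                           # [4 2 1]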
@@ -448,7 +450,7 @@ class FlashAttentionMetadataBuilder:
                 causal=True)

         attn_metadata = FlashAttentionMetadata(
-            num_actual_tokens=num_actual_tokens,
+            num_actual_tokens=num_tokens,
             max_query_len=max_query_len,
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
@@ -466,6 +468,17 @@ class FlashAttentionMetadataBuilder:
         )
         return attn_metadata

+    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata):
+        return self.build_slice(
+            max_query_len=max_query_len,
+            common_prefix_len=common_prefix_len,
+            common_attn_metadata=common_attn_metadata,
+            req_slice=slice(0, num_reqs),
+            token_slice=slice(0, num_actual_tokens),
+        )
+
     def use_cascade_attention(self, *args, **kwargs) -> bool:
         return use_cascade_attention(*args, **kwargs)
@@ -16,3 +16,11 @@ class CommonAttentionMetadata:
     seq_lens: torch.Tensor
     """(batch_size,), the length of each request including both computed tokens
    and newly scheduled tokens"""
+
+
+def slice_query_start_locs(
+    query_start_loc: torch.Tensor,
+    req_slice: slice,
+) -> torch.Tensor:
+    return query_start_loc[req_slice.start: req_slice.stop + 1] - \
+        query_start_loc[req_slice.start]
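slice_query_start_locs rebases a window of the cumulative offsets so the sliced metadata starts at zero. A quick worked example (numbers are illustrative):

    import torch
    from vllm.v1.attention.backends.utils import slice_query_start_locs

    # 4 requests with 3, 4, 2 and 5 new tokens respectively.
    query_start_loc = torch.tensor([0, 3, 7, 9, 14], dtype=torch.int32)

    # Take requests 2-3; note the extra trailing offset and the rebasing.
    print(slice_query_start_locs(query_start_loc, slice(2, 4)))
    # tensor([0, 2, 7], dtype=torch.int32)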
@@ -4,7 +4,7 @@ import copy
 import gc
 import time
 import weakref
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Optional, TypeAlias, Union

 import numpy as np
 import torch
@@ -69,6 +69,11 @@ else:

 logger = init_logger(__name__)

+AttnMetadataDict: TypeAlias = dict[str, FlashAttentionMetadata]
+# list when ubatching is enabled
+PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
+                                        AttnMetadataDict]
+

 class GPUModelRunner(LoRAModelRunnerMixin):
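The alias keeps both shapes under one name: a plain dict keyed by layer name when running the whole batch, or one such dict per microbatch when ubatching. A sketch using a stand-in metadata type (the layer name and FakeMetadata are placeholders, not from the diff):

    from typing import TypeAlias, Union

    class FakeMetadata:  # stand-in for FlashAttentionMetadata in this sketch
        pass

    AttnMetadataDict: TypeAlias = dict[str, FakeMetadata]
    PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
                                            AttnMetadataDict]

    plain: PerLayerAttnMetadata = {"layers.0.attn": FakeMetadata()}
    ubatched: PerLayerAttnMetadata = [
        {"layers.0.attn": FakeMetadata()},  # microbatch 0
        {"layers.0.attn": FakeMetadata()},  # microbatch 1
    ]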
@@ -491,7 +496,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> tuple[dict[str, FlashAttentionMetadata], torch.Tensor,
+        ubatch_slices: Optional[list[tuple[
+            slice, slice]]] = None,  # req_slice, token_slice
+    ) -> tuple[PerLayerAttnMetadata, torch.Tensor,
                Optional[SpecDecodeMetadata]]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
@@ -612,7 +619,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc, seq_lens=seq_lens)

-        attn_metadata: dict[str, FlashAttentionMetadata] = {}
+        attn_metadata: PerLayerAttnMetadata = {}
+        if ubatch_slices is not None:
+            attn_metadata = [dict() for _ in range(len(ubatch_slices))]

         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -629,15 +639,37 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 self.attn_metadata_builders[kv_cache_group_id],
             )

-            attn_metadata_i = (
-                self.attn_metadata_builders[kv_cache_group_id].build(
-                    num_reqs=num_reqs,
-                    num_actual_tokens=total_num_scheduled_tokens,
-                    max_query_len=max_num_scheduled_tokens,
-                    common_prefix_len=common_prefix_len,
-                    common_attn_metadata=common_attn_metadata))
-            for layer_name in kv_cache_group_spec.layer_names:
-                attn_metadata[layer_name] = attn_metadata_i
+            # Fill unused with -1. Needed for reshape_and_cache in full cuda
+            # graph mode.
+            if self.vllm_config.compilation_config.full_cuda_graph:
+                self.input_batch.block_table[kv_cache_group_id]\
+                    .slot_mapping.fill_(-1)
+
+            if ubatch_slices is not None:
+                for ubid, (req_slice, token_slice) in enumerate(ubatch_slices):
+                    attn_metadata_i = (
+                        self.attn_metadata_builders[kv_cache_group_id].
+                        build_slice(
+                            max(tokens[req_slice]),
+                            common_prefix_len=common_prefix_len,
+                            common_attn_metadata=common_attn_metadata,
+                            req_slice=req_slice,
+                            token_slice=token_slice,
+                        ))
+                    for layer_name in kv_cache_group_spec.layer_names:
+                        assert type(attn_metadata) is list
+                        attn_metadata[ubid][layer_name] = attn_metadata_i
+            else:
+                attn_metadata_i = (
+                    self.attn_metadata_builders[kv_cache_group_id].build(
+                        num_reqs=num_reqs,
+                        num_actual_tokens=total_num_scheduled_tokens,
+                        max_query_len=max_num_scheduled_tokens,
+                        common_prefix_len=common_prefix_len,
+                        common_attn_metadata=common_attn_metadata))
+                for layer_name in kv_cache_group_spec.layer_names:
+                    assert type(attn_metadata) is dict
+                    attn_metadata[layer_name] = attn_metadata_i

         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
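Nothing in this commit computes ubatch_slices yet; _prepare_inputs only accepts it. One way such slices could be derived from query_start_loc, as a hedged sketch (the helper name and the even-split policy are hypothetical):

    import torch

    def split_into_ubatches(query_start_loc: torch.Tensor,
                            num_ubatches: int = 2) -> list[tuple[slice, slice]]:
        # Split requests into contiguous groups; token boundaries always fall
        # on request boundaries, matching what build_slice() expects.
        num_reqs = query_start_loc.numel() - 1
        per_ubatch = (num_reqs + num_ubatches - 1) // num_ubatches
        slices = []
        for start in range(0, num_reqs, per_ubatch):
            stop = min(start + per_ubatch, num_reqs)
            slices.append((slice(start, stop),
                           slice(int(query_start_loc[start]),
                                 int(query_start_loc[stop]))))
        return slices

    print(split_into_ubatches(torch.tensor([0, 3, 7, 9, 14])))
    # [(slice(0, 2, None), slice(0, 7, None)), (slice(2, 4, None), slice(7, 14, None))]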
@@ -1057,10 +1089,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         )

     def sync_and_slice_intermediate_tensors(
-            self, num_tokens: int, intermediate_tensors: IntermediateTensors,
+            self, tokens_slice: slice,
+            intermediate_tensors: IntermediateTensors,
             sync_self: bool) -> IntermediateTensors:

         assert self.intermediate_tensors is not None
+        num_tokens = tokens_slice.stop - tokens_slice.start

         tp = self.vllm_config.parallel_config.tensor_parallel_size
         enabled_sp = self.vllm_config.compilation_config.pass_config. \
@@ -1072,49 +1106,37 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         is_residual_scattered = tp > 1 and enabled_sp \
             and num_tokens % tp == 0

+        def copy_slice(is_scattered: bool) -> slice:
+            if is_scattered:
+                return slice(tokens_slice.start // tp, tokens_slice.stop // tp)
+            else:
+                return tokens_slice
+
         # When sequence parallelism is enabled, the "residual" tensor is sharded
         # across tensor parallel ranks, so each rank only needs its own slice.
         if sync_self:
             assert intermediate_tensors is not None
             for k, v in intermediate_tensors.items():
                 is_scattered = k == "residual" and is_residual_scattered
-                copy_len = num_tokens // tp if is_scattered else \
-                    num_tokens
-                self.intermediate_tensors[k][:copy_len].copy_(
-                    v[:copy_len], non_blocking=True)
+                _copy_slice = copy_slice(is_scattered)
+                self.intermediate_tensors[k][_copy_slice].copy_(
+                    v[_copy_slice], non_blocking=True)

         return IntermediateTensors({
             k:
-            v[:num_tokens // tp]
-            if k == "residual" and is_residual_scattered else v[:num_tokens]
+            v[copy_slice(k == "residual" and is_residual_scattered)]
             for k, v in self.intermediate_tensors.items()
         })
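Because a sequence-parallel "residual" tensor holds only num_tokens // tp rows per rank, copy_slice shrinks both endpoints of the token window by tp; every other tensor keeps the full slice. With illustrative numbers:

    tp = 4
    tokens_slice = slice(8, 24)   # 16 tokens owned by this microbatch

    def copy_slice(is_scattered: bool) -> slice:
        if is_scattered:
            return slice(tokens_slice.start // tp, tokens_slice.stop // tp)
        return tokens_slice

    print(copy_slice(True))   # slice(2, 6): this rank's 4 of 16 rows
    print(copy_slice(False))  # slice(8, 24)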
-    @torch.inference_mode()
-    def execute_model(
-        self,
-        scheduler_output: "SchedulerOutput",
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
-
-        self._update_states(scheduler_output)
-        if not scheduler_output.total_num_scheduled_tokens:
-            if not has_kv_transfer_group():
-                # Return empty ModelRunnerOutput if there's no work to do.
-                return EMPTY_MODEL_RUNNER_OUTPUT
-
-            return self.kv_connector_no_forward(scheduler_output)
-
-        # Prepare the decoder inputs.
-        attn_metadata, logits_indices, spec_decode_metadata = (
-            self._prepare_inputs(scheduler_output))
-        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+    def _get_model_inputs(self, tokens_slice: slice,
+                          scheduler_output: "SchedulerOutput"):
+        num_tokens = tokens_slice.stop - tokens_slice.start
         if (self.use_cuda_graph
-                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
+                and num_tokens <= self.cudagraph_batch_sizes[-1]):
             # Use piecewise CUDA graphs.
             # Add padding to the batch size.
-            num_input_tokens = self.vllm_config.pad_for_cudagraph(
-                num_scheduled_tokens)
+            tokens_slice = \
+                slice(tokens_slice.start,
+                      self.vllm_config.pad_for_cudagraph(num_tokens))
         else:
             # Eager mode.
             # Pad tokens to multiple of tensor_parallel_size when
@@ -1123,9 +1145,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             if self.vllm_config.compilation_config.pass_config. \
                 enable_sequence_parallelism and tp_size > 1:
                 from vllm.utils import round_up
-                num_input_tokens = round_up(num_scheduled_tokens, tp_size)
-            else:
-                num_input_tokens = num_scheduled_tokens
+                tokens_slice = slice(tokens_slice.start,
+                                     round_up(num_tokens, tp_size))
+        num_tokens = tokens_slice.stop - tokens_slice.start

         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
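In the eager sequence-parallel path the slice end is rounded up to a multiple of tp_size so the padded tokens shard evenly across ranks. A small sketch of the arithmetic (this round_up mirrors vllm.utils.round_up):

    def round_up(x: int, multiple: int) -> int:
        return ((x + multiple - 1) // multiple) * multiple

    tp_size = 4
    tokens_slice = slice(0, 13)
    tokens_slice = slice(tokens_slice.start, round_up(13, tp_size))
    print(tokens_slice)  # slice(0, 16): three padding tokens appended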
@@ -1140,51 +1162,94 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
-            input_ids = self.input_ids[:num_scheduled_tokens]
+            input_ids = self.input_ids[tokens_slice]
             if mm_embeds:
                 inputs_embeds = self.model.get_input_embeddings(
                     input_ids, mm_embeds)
             else:
                 inputs_embeds = self.model.get_input_embeddings(input_ids)
             # TODO(woosuk): Avoid the copy. Optimize.
-            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
-            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            self.inputs_embeds[tokens_slice].copy_(inputs_embeds)
+            inputs_embeds = self.inputs_embeds[tokens_slice]
             input_ids = None
         else:
             # For text-only models, we use token ids as input.
             # While it is possible to use embeddings as input just like the
             # multimodal models, it is not desirable for performance since
             # then the embedding layer is not included in the CUDA graph.
-            input_ids = self.input_ids[:num_input_tokens]
+            input_ids = self.input_ids[tokens_slice]
             inputs_embeds = None
         if self.uses_mrope:
-            positions = self.mrope_positions[:, :num_input_tokens]
+            positions = self.mrope_positions[:, tokens_slice]
         else:
-            positions = self.positions[:num_input_tokens]
+            positions = self.positions[tokens_slice]

         if get_pp_group().is_first_rank:
             intermediate_tensors = None
         else:
             intermediate_tensors = self.sync_and_slice_intermediate_tensors(
-                num_input_tokens, intermediate_tensors, True)
+                tokens_slice, intermediate_tensors, True)
+        return input_ids, positions, inputs_embeds, intermediate_tensors
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
+        ubatch_slices: Optional[list[tuple[slice, slice]]] = None
+
+        self._update_states(scheduler_output)
+        if not scheduler_output.total_num_scheduled_tokens:
+            if not has_kv_transfer_group():
+                # Return empty ModelRunnerOutput if there's no work to do.
+                return EMPTY_MODEL_RUNNER_OUTPUT
+
+            return self.kv_connector_no_forward(scheduler_output)
+
+        # Prepare the decoder inputs.
+        attn_metadata, logits_indices, spec_decode_metadata = (
+            self._prepare_inputs(scheduler_output))
+        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens

         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
-        with set_forward_context(attn_metadata,
-                                 self.vllm_config,
-                                 num_tokens=num_input_tokens):
-            self.maybe_setup_kv_connector(scheduler_output)
+        self.maybe_setup_kv_connector(scheduler_output)

-            model_output = self.model(
-                input_ids=input_ids,
-                positions=positions,
-                intermediate_tensors=intermediate_tensors,
-                inputs_embeds=inputs_embeds,
-            )
+        if ubatch_slices is not None:
+            for i, (_, tokens_slice) in enumerate(ubatch_slices):
+                input_ids, positions, inputs_embeds, intermediate_tensors = \
+                    self._get_model_inputs(tokens_slice, scheduler_output)
+                num_input_tokens = tokens_slice.stop - tokens_slice.start

-        self.maybe_wait_for_kv_save()
-        finished_sending, finished_recving = (
-            self.get_finished_kv_transfers(scheduler_output))
+                with set_forward_context(attn_metadata[i],
+                                         self.vllm_config,
+                                         num_tokens=num_input_tokens):
+                    model_output = self.model(
+                        input_ids=input_ids,
+                        positions=positions,
+                        intermediate_tensors=intermediate_tensors,
+                        inputs_embeds=inputs_embeds,
+                    )
+        else:
+            input_ids, positions, inputs_embeds, intermediate_tensors = \
+                self._get_model_inputs(slice(0, num_scheduled_tokens),
+                                       scheduler_output)
+
+            with set_forward_context(attn_metadata,
+                                     self.vllm_config,
+                                     num_tokens=num_scheduled_tokens):
+                model_output = self.model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds,
+                )
+
+        self.maybe_wait_for_kv_save()
+        finished_sending, finished_recving = (
+            self.get_finished_kv_transfers(scheduler_output))

         if self.use_aux_hidden_state_outputs:
             hidden_states, aux_hidden_states = model_output
vllm/v1/worker/ubatching.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+# SPDX-License-Identifier: Apache-2.0
+class UBatchContext:
+
+    def __init__(self, ubatch_id: int):
+        self.ubatch_id = ubatch_id
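vllm/v1/worker/ubatching.py is only a stub so far. Purely as a guess at the intended direction (nothing in this commit wires it up), each microbatch in the execute_model loop could carry one of these contexts:

    from vllm.v1.worker.ubatching import UBatchContext

    # Hypothetical usage: tag each microbatch so downstream code can tell
    # which slice of the batch it is executing.
    ubatch_slices = [(slice(0, 2), slice(0, 7)), (slice(2, 4), slice(7, 14))]
    contexts = [UBatchContext(ubatch_id=i) for i in range(len(ubatch_slices))]
    for ctx, (req_slice, token_slice) in zip(contexts, ubatch_slices):
        print(ctx.ubatch_id, req_slice, token_slice)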