mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 07:29:07 +08:00
wip
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
980a172474
commit
8293182c8c
@ -31,6 +31,8 @@ if current_platform.is_cuda():
|
|||||||
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
||||||
get_scheduler_metadata)
|
get_scheduler_metadata)
|
||||||
|
|
||||||
|
from vllm.v1.attention.backends.utils import slice_query_start_locs
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -322,23 +324,23 @@ class FlashAttentionMetadataBuilder:
|
|||||||
scheduler_output: "SchedulerOutput") -> bool:
|
scheduler_output: "SchedulerOutput") -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
|
def build_slice(self, max_query_len: int, common_prefix_len: int,
|
||||||
common_prefix_len: int,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
common_attn_metadata: CommonAttentionMetadata):
|
req_slice: slice,
|
||||||
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
|
token_slice: slice) -> FlashAttentionMetadata:
|
||||||
query_start_loc = common_attn_metadata.query_start_loc
|
num_reqs = req_slice.stop - req_slice.start
|
||||||
seq_lens = common_attn_metadata.seq_lens
|
num_tokens = token_slice.stop - token_slice.start
|
||||||
|
|
||||||
|
max_seq_len = self.runner.seq_lens_np[req_slice].max()
|
||||||
|
query_start_loc = slice_query_start_locs(
|
||||||
|
common_attn_metadata.query_start_loc, req_slice)
|
||||||
|
seq_lens = common_attn_metadata.seq_lens[req_slice]
|
||||||
block_table = self.block_table
|
block_table = self.block_table
|
||||||
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
|
block_table_tensor = block_table.get_device_tensor()[req_slice]
|
||||||
|
|
||||||
block_table.slot_mapping[:num_actual_tokens].copy_(
|
block_table.slot_mapping[token_slice].copy_(
|
||||||
block_table.slot_mapping_cpu[:num_actual_tokens],
|
block_table.slot_mapping_cpu[token_slice], non_blocking=True)
|
||||||
non_blocking=True)
|
slot_mapping = block_table.slot_mapping[token_slice]
|
||||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda graph
|
|
||||||
# mode.
|
|
||||||
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
|
|
||||||
|
|
||||||
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
|
|
||||||
|
|
||||||
if self.aot_sliding_window is None:
|
if self.aot_sliding_window is None:
|
||||||
self.aot_sliding_window = (-1, -1)
|
self.aot_sliding_window = (-1, -1)
|
||||||
@ -380,8 +382,8 @@ class FlashAttentionMetadataBuilder:
|
|||||||
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
|
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
|
||||||
virt_block_table_tensor = make_local_attention_virtual_batches(
|
virt_block_table_tensor = make_local_attention_virtual_batches(
|
||||||
self.runner.attention_chunk_size,
|
self.runner.attention_chunk_size,
|
||||||
self.runner.query_start_loc_np[:num_reqs + 1],
|
query_start_loc,
|
||||||
self.runner.seq_lens_np[:num_reqs],
|
seq_lens,
|
||||||
block_table_tensor,
|
block_table_tensor,
|
||||||
self.block_size,
|
self.block_size,
|
||||||
)
|
)
|
||||||
@ -411,20 +413,20 @@ class FlashAttentionMetadataBuilder:
|
|||||||
use_cascade = common_prefix_len > 0
|
use_cascade = common_prefix_len > 0
|
||||||
|
|
||||||
if use_cascade:
|
if use_cascade:
|
||||||
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
|
cu_prefix_query_lens = torch.tensor([0, num_tokens],
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.runner.device)
|
device=self.runner.device)
|
||||||
prefix_kv_lens = torch.tensor([common_prefix_len],
|
prefix_kv_lens = torch.tensor([common_prefix_len],
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.runner.device)
|
device=self.runner.device)
|
||||||
suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
|
suffix_kv_lens = (self.runner.seq_lens_np[req_slice] -
|
||||||
common_prefix_len)
|
common_prefix_len)
|
||||||
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
|
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
|
||||||
self.runner.device)
|
self.runner.device)
|
||||||
prefix_scheduler_metadata = schedule(
|
prefix_scheduler_metadata = schedule(
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
cu_query_lens=cu_prefix_query_lens,
|
cu_query_lens=cu_prefix_query_lens,
|
||||||
max_query_len=num_actual_tokens,
|
max_query_len=num_tokens,
|
||||||
seqlens=prefix_kv_lens,
|
seqlens=prefix_kv_lens,
|
||||||
max_seq_len=common_prefix_len,
|
max_seq_len=common_prefix_len,
|
||||||
causal=False)
|
causal=False)
|
||||||
@ -448,7 +450,7 @@ class FlashAttentionMetadataBuilder:
|
|||||||
causal=True)
|
causal=True)
|
||||||
|
|
||||||
attn_metadata = FlashAttentionMetadata(
|
attn_metadata = FlashAttentionMetadata(
|
||||||
num_actual_tokens=num_actual_tokens,
|
num_actual_tokens=num_tokens,
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
query_start_loc=query_start_loc,
|
query_start_loc=query_start_loc,
|
||||||
max_seq_len=max_seq_len,
|
max_seq_len=max_seq_len,
|
||||||
@ -466,6 +468,17 @@ class FlashAttentionMetadataBuilder:
|
|||||||
)
|
)
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
|
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
|
||||||
|
common_prefix_len: int,
|
||||||
|
common_attn_metadata: CommonAttentionMetadata):
|
||||||
|
return self.build_slice(
|
||||||
|
max_query_len=max_query_len,
|
||||||
|
common_prefix_len=common_prefix_len,
|
||||||
|
common_attn_metadata=common_attn_metadata,
|
||||||
|
req_slice=slice(0, num_reqs),
|
||||||
|
token_slice=slice(0, num_actual_tokens),
|
||||||
|
)
|
||||||
|
|
||||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||||
return use_cascade_attention(*args, **kwargs)
|
return use_cascade_attention(*args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@ -16,3 +16,11 @@ class CommonAttentionMetadata:
|
|||||||
seq_lens: torch.Tensor
|
seq_lens: torch.Tensor
|
||||||
"""(batch_size,), the length of each request including both computed tokens
|
"""(batch_size,), the length of each request including both computed tokens
|
||||||
and newly scheduled tokens"""
|
and newly scheduled tokens"""
|
||||||
|
|
||||||
|
|
||||||
|
def slice_query_start_locs(
|
||||||
|
query_start_loc: torch.Tensor,
|
||||||
|
req_slice: slice,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return query_start_loc[req_slice.start: req_slice.stop + 1] -\
|
||||||
|
query_start_loc[req_slice.start]
|
||||||
@ -4,7 +4,7 @@ import copy
|
|||||||
import gc
|
import gc
|
||||||
import time
|
import time
|
||||||
import weakref
|
import weakref
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, Optional, TypeAlias, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -69,6 +69,11 @@ else:
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
AttnMetadataDict: TypeAlias = dict[str, FlashAttentionMetadata]
|
||||||
|
# list when ubatching is enabled
|
||||||
|
PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
|
||||||
|
AttnMetadataDict]
|
||||||
|
|
||||||
|
|
||||||
class GPUModelRunner(LoRAModelRunnerMixin):
|
class GPUModelRunner(LoRAModelRunnerMixin):
|
||||||
|
|
||||||
@ -491,7 +496,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
def _prepare_inputs(
|
def _prepare_inputs(
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
) -> tuple[dict[str, FlashAttentionMetadata], torch.Tensor,
|
ubatch_slices: Optional[list[tuple[
|
||||||
|
slice, slice]]] = None, # req_slice, token_slice
|
||||||
|
) -> tuple[PerLayerAttnMetadata, torch.Tensor,
|
||||||
Optional[SpecDecodeMetadata]]:
|
Optional[SpecDecodeMetadata]]:
|
||||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
assert total_num_scheduled_tokens > 0
|
assert total_num_scheduled_tokens > 0
|
||||||
@ -612,7 +619,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
common_attn_metadata = CommonAttentionMetadata(
|
common_attn_metadata = CommonAttentionMetadata(
|
||||||
query_start_loc=query_start_loc, seq_lens=seq_lens)
|
query_start_loc=query_start_loc, seq_lens=seq_lens)
|
||||||
|
|
||||||
attn_metadata: dict[str, FlashAttentionMetadata] = {}
|
attn_metadata: PerLayerAttnMetadata = {}
|
||||||
|
if ubatch_slices is not None:
|
||||||
|
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
|
||||||
|
|
||||||
# Prepare the attention metadata for each KV cache group and make layers
|
# Prepare the attention metadata for each KV cache group and make layers
|
||||||
# in the same group share the same metadata.
|
# in the same group share the same metadata.
|
||||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||||
@ -629,15 +639,37 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.attn_metadata_builders[kv_cache_group_id],
|
self.attn_metadata_builders[kv_cache_group_id],
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_metadata_i = (
|
# Fill unused with -1. Needed for reshape_and_cache in full cuda
|
||||||
self.attn_metadata_builders[kv_cache_group_id].build(
|
# graph mode.
|
||||||
num_reqs=num_reqs,
|
if self.vllm_config.compilation_config.full_cuda_graph:
|
||||||
num_actual_tokens=total_num_scheduled_tokens,
|
self.input_batch.block_table[kv_cache_group_id]\
|
||||||
max_query_len=max_num_scheduled_tokens,
|
.slot_mapping.fill_(-1)
|
||||||
common_prefix_len=common_prefix_len,
|
|
||||||
common_attn_metadata=common_attn_metadata))
|
if ubatch_slices is not None:
|
||||||
for layer_name in kv_cache_group_spec.layer_names:
|
for ubid, (req_slice, token_slice) in enumerate(ubatch_slices):
|
||||||
attn_metadata[layer_name] = attn_metadata_i
|
attn_metadata_i = (
|
||||||
|
self.attn_metadata_builders[kv_cache_group_id].
|
||||||
|
build_slice(
|
||||||
|
max(tokens[req_slice]),
|
||||||
|
common_prefix_len=common_prefix_len,
|
||||||
|
common_attn_metadata=common_attn_metadata,
|
||||||
|
req_slice=req_slice,
|
||||||
|
token_slice=token_slice,
|
||||||
|
))
|
||||||
|
for layer_name in kv_cache_group_spec.layer_names:
|
||||||
|
assert type(attn_metadata) is list
|
||||||
|
attn_metadata[ubid][layer_name] = attn_metadata_i
|
||||||
|
else:
|
||||||
|
attn_metadata_i = (
|
||||||
|
self.attn_metadata_builders[kv_cache_group_id].build(
|
||||||
|
num_reqs=num_reqs,
|
||||||
|
num_actual_tokens=total_num_scheduled_tokens,
|
||||||
|
max_query_len=max_num_scheduled_tokens,
|
||||||
|
common_prefix_len=common_prefix_len,
|
||||||
|
common_attn_metadata=common_attn_metadata))
|
||||||
|
for layer_name in kv_cache_group_spec.layer_names:
|
||||||
|
assert type(attn_metadata) is dict
|
||||||
|
attn_metadata[layer_name] = attn_metadata_i
|
||||||
|
|
||||||
use_spec_decode = len(
|
use_spec_decode = len(
|
||||||
scheduler_output.scheduled_spec_decode_tokens) > 0
|
scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||||
@ -1057,10 +1089,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def sync_and_slice_intermediate_tensors(
|
def sync_and_slice_intermediate_tensors(
|
||||||
self, num_tokens: int, intermediate_tensors: IntermediateTensors,
|
self, tokens_slice: slice,
|
||||||
|
intermediate_tensors: IntermediateTensors,
|
||||||
sync_self: bool) -> IntermediateTensors:
|
sync_self: bool) -> IntermediateTensors:
|
||||||
|
|
||||||
assert self.intermediate_tensors is not None
|
assert self.intermediate_tensors is not None
|
||||||
|
num_tokens = tokens_slice.stop - tokens_slice.start
|
||||||
|
|
||||||
tp = self.vllm_config.parallel_config.tensor_parallel_size
|
tp = self.vllm_config.parallel_config.tensor_parallel_size
|
||||||
enabled_sp = self.vllm_config.compilation_config.pass_config. \
|
enabled_sp = self.vllm_config.compilation_config.pass_config. \
|
||||||
@ -1072,49 +1106,37 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
is_residual_scattered = tp > 1 and enabled_sp \
|
is_residual_scattered = tp > 1 and enabled_sp \
|
||||||
and num_tokens % tp == 0
|
and num_tokens % tp == 0
|
||||||
|
|
||||||
|
def copy_slice(is_scattered: bool) -> slice:
|
||||||
|
if is_scattered:
|
||||||
|
return slice(tokens_slice.start // tp, tokens_slice.stop // tp)
|
||||||
|
else:
|
||||||
|
return tokens_slice
|
||||||
|
|
||||||
# When sequence parallelism is enabled, the "residual" tensor is sharded
|
# When sequence parallelism is enabled, the "residual" tensor is sharded
|
||||||
# across tensor parallel ranks, so each rank only needs its own slice.
|
# across tensor parallel ranks, so each rank only needs its own slice.
|
||||||
if sync_self:
|
if sync_self:
|
||||||
assert intermediate_tensors is not None
|
assert intermediate_tensors is not None
|
||||||
for k, v in intermediate_tensors.items():
|
for k, v in intermediate_tensors.items():
|
||||||
is_scattered = "residual" and is_residual_scattered
|
is_scattered = "residual" and is_residual_scattered
|
||||||
copy_len = num_tokens // tp if is_scattered else \
|
_copy_slice = copy_slice(is_scattered)
|
||||||
num_tokens
|
self.intermediate_tensors[k][_copy_slice].copy_(
|
||||||
self.intermediate_tensors[k][:copy_len].copy_(
|
v[_copy_slice], non_blocking=True)
|
||||||
v[:copy_len], non_blocking=True)
|
|
||||||
|
|
||||||
return IntermediateTensors({
|
return IntermediateTensors({
|
||||||
k:
|
k:
|
||||||
v[:num_tokens // tp]
|
v[copy_slice(k == "residual" and is_residual_scattered)]
|
||||||
if k == "residual" and is_residual_scattered else v[:num_tokens]
|
|
||||||
for k, v in self.intermediate_tensors.items()
|
for k, v in self.intermediate_tensors.items()
|
||||||
})
|
})
|
||||||
|
|
||||||
@torch.inference_mode()
|
def _get_model_inputs(self, tokens_slice: slice,
|
||||||
def execute_model(
|
scheduler_output: "SchedulerOutput"):
|
||||||
self,
|
num_tokens = tokens_slice.stop - tokens_slice.start
|
||||||
scheduler_output: "SchedulerOutput",
|
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
||||||
) -> Union[ModelRunnerOutput, IntermediateTensors]:
|
|
||||||
|
|
||||||
self._update_states(scheduler_output)
|
|
||||||
if not scheduler_output.total_num_scheduled_tokens:
|
|
||||||
if not has_kv_transfer_group():
|
|
||||||
# Return empty ModelRunnerOutput if there's no work to do.
|
|
||||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
|
||||||
|
|
||||||
return self.kv_connector_no_forward(scheduler_output)
|
|
||||||
|
|
||||||
# Prepare the decoder inputs.
|
|
||||||
attn_metadata, logits_indices, spec_decode_metadata = (
|
|
||||||
self._prepare_inputs(scheduler_output))
|
|
||||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
|
||||||
if (self.use_cuda_graph
|
if (self.use_cuda_graph
|
||||||
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
|
and num_tokens <= self.cudagraph_batch_sizes[-1]):
|
||||||
# Use piecewise CUDA graphs.
|
# Use piecewise CUDA graphs.
|
||||||
# Add padding to the batch size.
|
# Add padding to the batch size.
|
||||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(
|
tokens_slice = \
|
||||||
num_scheduled_tokens)
|
slice(tokens_slice.start, self.vllm_config.pad_for_cudagraph(num_tokens))
|
||||||
else:
|
else:
|
||||||
# Eager mode.
|
# Eager mode.
|
||||||
# Pad tokens to multiple of tensor_parallel_size when
|
# Pad tokens to multiple of tensor_parallel_size when
|
||||||
@ -1123,9 +1145,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if self.vllm_config.compilation_config.pass_config. \
|
if self.vllm_config.compilation_config.pass_config. \
|
||||||
enable_sequence_parallelism and tp_size > 1:
|
enable_sequence_parallelism and tp_size > 1:
|
||||||
from vllm.utils import round_up
|
from vllm.utils import round_up
|
||||||
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
|
tokens_slice = slice(tokens_slice.start,
|
||||||
else:
|
round_up(num_tokens, tp_size))
|
||||||
num_input_tokens = num_scheduled_tokens
|
num_tokens = tokens_slice.stop - tokens_slice.start
|
||||||
|
|
||||||
# _prepare_inputs may reorder the batch, so we must gather multi
|
# _prepare_inputs may reorder the batch, so we must gather multi
|
||||||
# modal outputs after that to ensure the correct order
|
# modal outputs after that to ensure the correct order
|
||||||
@ -1140,51 +1162,94 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
||||||
# embeddings), we always use embeddings (rather than token ids)
|
# embeddings), we always use embeddings (rather than token ids)
|
||||||
# as input to the multimodal model, even when the input is text.
|
# as input to the multimodal model, even when the input is text.
|
||||||
input_ids = self.input_ids[:num_scheduled_tokens]
|
input_ids = self.input_ids[tokens_slice]
|
||||||
if mm_embeds:
|
if mm_embeds:
|
||||||
inputs_embeds = self.model.get_input_embeddings(
|
inputs_embeds = self.model.get_input_embeddings(
|
||||||
input_ids, mm_embeds)
|
input_ids, mm_embeds)
|
||||||
else:
|
else:
|
||||||
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
||||||
# TODO(woosuk): Avoid the copy. Optimize.
|
# TODO(woosuk): Avoid the copy. Optimize.
|
||||||
self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
|
self.inputs_embeds[tokens_slice].copy_(inputs_embeds)
|
||||||
inputs_embeds = self.inputs_embeds[:num_input_tokens]
|
inputs_embeds = self.inputs_embeds[tokens_slice]
|
||||||
input_ids = None
|
input_ids = None
|
||||||
else:
|
else:
|
||||||
# For text-only models, we use token ids as input.
|
# For text-only models, we use token ids as input.
|
||||||
# While it is possible to use embeddings as input just like the
|
# While it is possible to use embeddings as input just like the
|
||||||
# multimodal models, it is not desirable for performance since
|
# multimodal models, it is not desirable for performance since
|
||||||
# then the embedding layer is not included in the CUDA graph.
|
# then the embedding layer is not included in the CUDA graph.
|
||||||
input_ids = self.input_ids[:num_input_tokens]
|
input_ids = self.input_ids[tokens_slice]
|
||||||
inputs_embeds = None
|
inputs_embeds = None
|
||||||
if self.uses_mrope:
|
if self.uses_mrope:
|
||||||
positions = self.mrope_positions[:, :num_input_tokens]
|
positions = self.mrope_positions[:, tokens_slice]
|
||||||
else:
|
else:
|
||||||
positions = self.positions[:num_input_tokens]
|
positions = self.positions[tokens_slice]
|
||||||
|
|
||||||
if get_pp_group().is_first_rank:
|
if get_pp_group().is_first_rank:
|
||||||
intermediate_tensors = None
|
intermediate_tensors = None
|
||||||
else:
|
else:
|
||||||
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
|
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
|
||||||
num_input_tokens, intermediate_tensors, True)
|
tokens_slice, intermediate_tensors, True)
|
||||||
|
return input_ids, positions, inputs_embeds, intermediate_tensors
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def execute_model(
|
||||||
|
self,
|
||||||
|
scheduler_output: "SchedulerOutput",
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
) -> Union[ModelRunnerOutput, IntermediateTensors]:
|
||||||
|
ubatch_slices: Optional[list[tuple[slice, slice]]] = None
|
||||||
|
|
||||||
|
self._update_states(scheduler_output)
|
||||||
|
if not scheduler_output.total_num_scheduled_tokens:
|
||||||
|
if not has_kv_transfer_group():
|
||||||
|
# Return empty ModelRunnerOutput if there's no work to do.
|
||||||
|
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||||
|
|
||||||
|
return self.kv_connector_no_forward(scheduler_output)
|
||||||
|
|
||||||
|
# Prepare the decoder inputs.
|
||||||
|
attn_metadata, logits_indices, spec_decode_metadata = (
|
||||||
|
self._prepare_inputs(scheduler_output))
|
||||||
|
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
|
|
||||||
# Run the decoder.
|
# Run the decoder.
|
||||||
# Use persistent buffers for CUDA graphs.
|
# Use persistent buffers for CUDA graphs.
|
||||||
with set_forward_context(attn_metadata,
|
self.maybe_setup_kv_connector(scheduler_output)
|
||||||
self.vllm_config,
|
|
||||||
num_tokens=num_input_tokens):
|
|
||||||
self.maybe_setup_kv_connector(scheduler_output)
|
|
||||||
|
|
||||||
model_output = self.model(
|
if ubatch_slices is not None:
|
||||||
input_ids=input_ids,
|
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
||||||
positions=positions,
|
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||||
intermediate_tensors=intermediate_tensors,
|
self._get_model_inputs(tokens_slice, scheduler_output)
|
||||||
inputs_embeds=inputs_embeds,
|
num_input_token = tokens_slice.stop - tokens_slice.start
|
||||||
)
|
|
||||||
|
|
||||||
self.maybe_wait_for_kv_save()
|
with set_forward_context(attn_metadata[i],
|
||||||
finished_sending, finished_recving = (
|
self.vllm_config,
|
||||||
self.get_finished_kv_transfers(scheduler_output))
|
num_tokens=num_input_tokens):
|
||||||
|
|
||||||
|
model_output = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||||
|
self._get_model_inputs(slice(0, num_scheduled_tokens),
|
||||||
|
scheduler_output)
|
||||||
|
|
||||||
|
with set_forward_context(attn_metadata,
|
||||||
|
self.vllm_config,
|
||||||
|
num_tokens=num_scheduled_tokens):
|
||||||
|
model_output = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.maybe_wait_for_kv_save()
|
||||||
|
finished_sending, finished_recving = (
|
||||||
|
self.get_finished_kv_transfers(scheduler_output))
|
||||||
|
|
||||||
if self.use_aux_hidden_state_outputs:
|
if self.use_aux_hidden_state_outputs:
|
||||||
hidden_states, aux_hidden_states = model_output
|
hidden_states, aux_hidden_states = model_output
|
||||||
|
|||||||
5
vllm/v1/worker/ubatching.py
Normal file
5
vllm/v1/worker/ubatching.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
class UBatchContext:
|
||||||
|
|
||||||
|
def __init__(self, ubatch_id: int):
|
||||||
|
self.ubatch_id = ubatch_id
|
||||||
Loading…
x
Reference in New Issue
Block a user