mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-01 12:57:06 +08:00
Chunked prompt works!
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
This commit is contained in:
parent
0bddb6b9a5
commit
61bb55f3d5
@ -213,11 +213,12 @@ class Scheduler:
|
|||||||
num_new_tokens = self.block_size
|
num_new_tokens = self.block_size
|
||||||
computed_blocks.pop()
|
computed_blocks.pop()
|
||||||
|
|
||||||
|
# TODO: Remove
|
||||||
# If chunked prefill is not enabled, then breakout of the loop
|
# If chunked prefill is not enabled, then breakout of the loop
|
||||||
# when above budget.
|
# when above budget.
|
||||||
if (not self.scheduler_config.chunked_prefill_enabled
|
# if (not self.scheduler_config.chunked_prefill_enabled
|
||||||
and num_new_tokens > token_budget):
|
# and num_new_tokens > token_budget):
|
||||||
break
|
# break
|
||||||
|
|
||||||
num_new_tokens = min(num_new_tokens, token_budget)
|
num_new_tokens = min(num_new_tokens, token_budget)
|
||||||
assert num_new_tokens > 0
|
assert num_new_tokens > 0
|
||||||
|
|||||||
@ -9,23 +9,18 @@ import torch.distributed
|
|||||||
from vllm.config import CompilationLevel, VllmConfig
|
from vllm.config import CompilationLevel, VllmConfig
|
||||||
from vllm.distributed.parallel_state import graph_capture
|
from vllm.distributed.parallel_state import graph_capture
|
||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.inputs import INPUT_REGISTRY
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||||
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||||
from vllm.sampling_params import SamplingType
|
|
||||||
from vllm.utils import DeviceMemoryProfiler, cdiv
|
from vllm.utils import DeviceMemoryProfiler, cdiv
|
||||||
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
|
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
|
||||||
FlashAttentionMetadata)
|
FlashAttentionMetadata)
|
||||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
|
||||||
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
|
|
||||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
|
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
|
||||||
from vllm.v1.outputs import ModelRunnerOutput
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.utils import bind_kv_cache
|
from vllm.v1.utils import bind_kv_cache
|
||||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
|
||||||
from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase
|
from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -43,41 +38,9 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
):
|
):
|
||||||
super().__init__(vllm_config, device)
|
super().__init__(vllm_config, device)
|
||||||
|
|
||||||
# Persistent batch.
|
|
||||||
self.input_batch = InputBatch(
|
|
||||||
max_num_reqs=self.max_num_reqs,
|
|
||||||
max_model_len=self.max_model_len,
|
|
||||||
max_num_blocks_per_req=self.max_num_blocks_per_req,
|
|
||||||
device=self.device,
|
|
||||||
pin_memory=self.pin_memory,
|
|
||||||
vocab_size=self.model_config.get_vocab_size(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Request states.
|
|
||||||
self.requests: Dict[str, CachedRequestState] = {}
|
|
||||||
|
|
||||||
# KV caches for forward pass
|
# KV caches for forward pass
|
||||||
self.kv_caches: List[torch.Tensor] = []
|
self.kv_caches: List[torch.Tensor] = []
|
||||||
|
|
||||||
# Multi-modal data support
|
|
||||||
self.input_registry = INPUT_REGISTRY
|
|
||||||
self.mm_registry = MULTIMODAL_REGISTRY
|
|
||||||
|
|
||||||
# NOTE: Initialized input mapper is only used for processing dummy
|
|
||||||
# multimodal data into multimodal kwargs for GPU memory profiling.
|
|
||||||
self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
|
|
||||||
self.mm_input_mapper_profiling.use_cache = False
|
|
||||||
|
|
||||||
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
|
|
||||||
model_config=self.model_config,
|
|
||||||
scheduler_config=self.scheduler_config,
|
|
||||||
)
|
|
||||||
self.max_num_encoder_input_tokens = encoder_compute_budget
|
|
||||||
self.encoder_cache_size = encoder_cache_size
|
|
||||||
|
|
||||||
# req_id -> (input_id -> encoder_output)
|
|
||||||
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
|
|
||||||
|
|
||||||
self.use_cuda_graph = (self.vllm_config.compilation_config.level
|
self.use_cuda_graph = (self.vllm_config.compilation_config.level
|
||||||
== CompilationLevel.PIECEWISE
|
== CompilationLevel.PIECEWISE
|
||||||
and not self.model_config.enforce_eager)
|
and not self.model_config.enforce_eager)
|
||||||
@ -160,132 +123,6 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
pin_memory=self.pin_memory)
|
pin_memory=self.pin_memory)
|
||||||
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
||||||
|
|
||||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
|
||||||
# Remove stopped requests from the cached states.
|
|
||||||
# Keep the states of the pre-empted requests.
|
|
||||||
for req_id in scheduler_output.finished_req_ids:
|
|
||||||
self.requests.pop(req_id, None)
|
|
||||||
self.encoder_cache.pop(req_id, None)
|
|
||||||
|
|
||||||
# Free the cached encoder outputs.
|
|
||||||
for req_id, input_id in scheduler_output.free_encoder_input_ids:
|
|
||||||
encoder_outputs = self.encoder_cache.get(req_id)
|
|
||||||
if encoder_outputs is not None:
|
|
||||||
encoder_outputs.pop(input_id, None)
|
|
||||||
if not encoder_outputs:
|
|
||||||
self.encoder_cache.pop(req_id, None)
|
|
||||||
|
|
||||||
# Remove the requests from the persistent batch.
|
|
||||||
stopped_req_ids = set().union(
|
|
||||||
scheduler_output.preempted_req_ids,
|
|
||||||
scheduler_output.finished_req_ids,
|
|
||||||
)
|
|
||||||
removed_req_indices: List[int] = []
|
|
||||||
for req_id in stopped_req_ids:
|
|
||||||
req_index = self.input_batch.remove_request(req_id)
|
|
||||||
if req_index is not None:
|
|
||||||
removed_req_indices.append(req_index)
|
|
||||||
|
|
||||||
# Update the states of the running requests.
|
|
||||||
for req_data in scheduler_output.scheduled_running_reqs:
|
|
||||||
req_id = req_data.req_id
|
|
||||||
req_state = self.requests[req_id]
|
|
||||||
req_index = self.input_batch.req_id_to_index[req_id]
|
|
||||||
|
|
||||||
# Update the num_computed_tokens.
|
|
||||||
req_state.num_computed_tokens = req_data.num_computed_tokens
|
|
||||||
self.input_batch.num_computed_tokens_cpu[req_index] = (
|
|
||||||
req_data.num_computed_tokens)
|
|
||||||
|
|
||||||
# Update the block table.
|
|
||||||
num_new_blocks = len(req_data.new_block_ids)
|
|
||||||
if num_new_blocks == 0:
|
|
||||||
continue
|
|
||||||
start_index = len(req_state.block_ids)
|
|
||||||
req_state.block_ids.extend(req_data.new_block_ids)
|
|
||||||
self.input_batch.block_table.append_row(req_index, start_index,
|
|
||||||
req_data.new_block_ids)
|
|
||||||
|
|
||||||
req_ids_to_add: List[str] = []
|
|
||||||
# Add new requests to the cached states.
|
|
||||||
for new_req_data in scheduler_output.scheduled_new_reqs:
|
|
||||||
req_id = new_req_data.req_id
|
|
||||||
sampling_params = new_req_data.sampling_params
|
|
||||||
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
|
||||||
generator = torch.Generator(device=self.device)
|
|
||||||
generator.manual_seed(sampling_params.seed)
|
|
||||||
else:
|
|
||||||
generator = None
|
|
||||||
|
|
||||||
self.requests[req_id] = CachedRequestState(
|
|
||||||
req_id=req_id,
|
|
||||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
|
||||||
prompt=new_req_data.prompt,
|
|
||||||
mm_inputs=new_req_data.mm_inputs,
|
|
||||||
mm_positions=new_req_data.mm_positions,
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
generator=generator,
|
|
||||||
block_ids=new_req_data.block_ids,
|
|
||||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
|
||||||
output_token_ids=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
|
||||||
if self.model_config.uses_mrope:
|
|
||||||
image_grid_thw = []
|
|
||||||
video_grid_thw = []
|
|
||||||
for mm_input in self.requests[req_id].mm_inputs:
|
|
||||||
if mm_input.get("image_grid_thw") is not None:
|
|
||||||
image_grid_thw.extend(
|
|
||||||
mm_input["image_grid_thw"].tolist())
|
|
||||||
if mm_input.get("video_grid_thw") is not None:
|
|
||||||
video_grid_thw.extend(
|
|
||||||
mm_input["video_grid_thw"].tolist())
|
|
||||||
|
|
||||||
hf_config = self.model_config.hf_config
|
|
||||||
|
|
||||||
self.requests[req_id].mrope_positions, \
|
|
||||||
self.requests[req_id].mrope_position_delta = \
|
|
||||||
MRotaryEmbedding.get_input_positions_tensor(
|
|
||||||
self.requests[req_id].prompt_token_ids,
|
|
||||||
image_grid_thw=image_grid_thw,
|
|
||||||
video_grid_thw=video_grid_thw,
|
|
||||||
image_token_id=hf_config.image_token_id,
|
|
||||||
video_token_id=hf_config.video_token_id,
|
|
||||||
vision_start_token_id=hf_config.vision_start_token_id,
|
|
||||||
vision_end_token_id=hf_config.vision_end_token_id,
|
|
||||||
spatial_merge_size=hf_config.vision_config.
|
|
||||||
spatial_merge_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
req_ids_to_add.append(req_id)
|
|
||||||
|
|
||||||
# Update the cached states of the resumed requests.
|
|
||||||
for res_req_data in scheduler_output.scheduled_resumed_reqs:
|
|
||||||
req_id = res_req_data.req_id
|
|
||||||
req_state = self.requests[req_id]
|
|
||||||
|
|
||||||
req_state.block_ids = res_req_data.block_ids
|
|
||||||
req_state.num_computed_tokens = res_req_data.num_computed_tokens
|
|
||||||
req_ids_to_add.append(req_id)
|
|
||||||
|
|
||||||
# Add the new or resumed requests to the persistent batch.
|
|
||||||
# The smaller empty indices are filled first.
|
|
||||||
removed_req_indices = sorted(removed_req_indices, reverse=True)
|
|
||||||
for req_id in req_ids_to_add:
|
|
||||||
req_state = self.requests[req_id]
|
|
||||||
if removed_req_indices:
|
|
||||||
# Fill the empty index.
|
|
||||||
req_index = removed_req_indices.pop()
|
|
||||||
else:
|
|
||||||
# Append to the end.
|
|
||||||
req_index = None
|
|
||||||
self.input_batch.add_request(req_state, req_index)
|
|
||||||
|
|
||||||
# Condense the batched states if there are empty indices.
|
|
||||||
if removed_req_indices:
|
|
||||||
self.input_batch.condense(removed_req_indices)
|
|
||||||
|
|
||||||
def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
|
def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
|
||||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
assert total_num_scheduled_tokens > 0
|
assert total_num_scheduled_tokens > 0
|
||||||
@ -665,7 +502,7 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
) -> ModelRunnerOutput:
|
) -> ModelRunnerOutput:
|
||||||
assert self.model is not None
|
assert self.model is not None
|
||||||
|
|
||||||
self._update_states(scheduler_output)
|
self.update_states(scheduler_output)
|
||||||
|
|
||||||
if self.is_multimodal_model:
|
if self.is_multimodal_model:
|
||||||
# Run the multimodal encoder if any.
|
# Run the multimodal encoder if any.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import enum
|
import enum
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
@ -8,11 +8,18 @@ import torch.nn as nn
|
|||||||
from vllm.attention.backends.abstract import AttentionType
|
from vllm.attention.backends.abstract import AttentionType
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.inputs import INPUT_REGISTRY
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||||
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
|
from vllm.sampling_params import SamplingType
|
||||||
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
|
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
|
||||||
|
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||||
|
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheSpec)
|
KVCacheSpec)
|
||||||
from vllm.v1.outputs import ModelRunnerOutput
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.v1.core.scheduler import SchedulerOutput
|
from vllm.v1.core.scheduler import SchedulerOutput
|
||||||
@ -75,6 +82,164 @@ class ModelRunnerBase:
|
|||||||
|
|
||||||
self.model: Optional[nn.Module] = None
|
self.model: Optional[nn.Module] = None
|
||||||
|
|
||||||
|
# Persistent batch.
|
||||||
|
self.input_batch = InputBatch(
|
||||||
|
max_num_reqs=self.max_num_reqs,
|
||||||
|
max_model_len=self.max_model_len,
|
||||||
|
max_num_blocks_per_req=self.max_num_blocks_per_req,
|
||||||
|
device=self.device,
|
||||||
|
pin_memory=self.pin_memory,
|
||||||
|
vocab_size=self.model_config.get_vocab_size(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Request states.
|
||||||
|
self.requests: Dict[str, CachedRequestState] = {}
|
||||||
|
|
||||||
|
# Multi-modal data support
|
||||||
|
self.input_registry = INPUT_REGISTRY
|
||||||
|
self.mm_registry = MULTIMODAL_REGISTRY
|
||||||
|
|
||||||
|
# NOTE: Initialized input mapper is only used for processing dummy
|
||||||
|
# multimodal data into multimodal kwargs for GPU memory profiling.
|
||||||
|
self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
|
||||||
|
self.mm_input_mapper_profiling.use_cache = False
|
||||||
|
|
||||||
|
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
|
||||||
|
model_config=self.model_config,
|
||||||
|
scheduler_config=self.scheduler_config,
|
||||||
|
)
|
||||||
|
self.max_num_encoder_input_tokens = encoder_compute_budget
|
||||||
|
self.encoder_cache_size = encoder_cache_size
|
||||||
|
|
||||||
|
# req_id -> (input_id -> encoder_output)
|
||||||
|
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
|
||||||
|
|
||||||
|
def update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||||
|
# Remove stopped requests from the cached states.
|
||||||
|
# Keep the states of the pre-empted requests.
|
||||||
|
for req_id in scheduler_output.finished_req_ids:
|
||||||
|
self.requests.pop(req_id, None)
|
||||||
|
self.encoder_cache.pop(req_id, None)
|
||||||
|
|
||||||
|
# Free the cached encoder outputs.
|
||||||
|
for req_id, input_id in scheduler_output.free_encoder_input_ids:
|
||||||
|
encoder_outputs = self.encoder_cache.get(req_id)
|
||||||
|
if encoder_outputs is not None:
|
||||||
|
encoder_outputs.pop(input_id, None)
|
||||||
|
if not encoder_outputs:
|
||||||
|
self.encoder_cache.pop(req_id, None)
|
||||||
|
|
||||||
|
# Remove the requests from the persistent batch.
|
||||||
|
stopped_req_ids = set().union(
|
||||||
|
scheduler_output.preempted_req_ids,
|
||||||
|
scheduler_output.finished_req_ids,
|
||||||
|
)
|
||||||
|
removed_req_indices: List[int] = []
|
||||||
|
for req_id in stopped_req_ids:
|
||||||
|
req_index = self.input_batch.remove_request(req_id)
|
||||||
|
if req_index is not None:
|
||||||
|
removed_req_indices.append(req_index)
|
||||||
|
|
||||||
|
# Update the states of the running requests.
|
||||||
|
for req_data in scheduler_output.scheduled_running_reqs:
|
||||||
|
req_id = req_data.req_id
|
||||||
|
req_state = self.requests[req_id]
|
||||||
|
req_index = self.input_batch.req_id_to_index[req_id]
|
||||||
|
|
||||||
|
# Update the num_computed_tokens.
|
||||||
|
req_state.num_computed_tokens = req_data.num_computed_tokens
|
||||||
|
self.input_batch.num_computed_tokens_cpu[req_index] = (
|
||||||
|
req_data.num_computed_tokens)
|
||||||
|
|
||||||
|
# Update the block table.
|
||||||
|
num_new_blocks = len(req_data.new_block_ids)
|
||||||
|
if num_new_blocks == 0:
|
||||||
|
continue
|
||||||
|
start_index = len(req_state.block_ids)
|
||||||
|
req_state.block_ids.extend(req_data.new_block_ids)
|
||||||
|
self.input_batch.block_table.append_row(req_index, start_index,
|
||||||
|
req_data.new_block_ids)
|
||||||
|
|
||||||
|
req_ids_to_add: List[str] = []
|
||||||
|
# Add new requests to the cached states.
|
||||||
|
for new_req_data in scheduler_output.scheduled_new_reqs:
|
||||||
|
req_id = new_req_data.req_id
|
||||||
|
sampling_params = new_req_data.sampling_params
|
||||||
|
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
||||||
|
generator = torch.Generator(device=self.device)
|
||||||
|
generator.manual_seed(sampling_params.seed)
|
||||||
|
else:
|
||||||
|
generator = None
|
||||||
|
|
||||||
|
self.requests[req_id] = CachedRequestState(
|
||||||
|
req_id=req_id,
|
||||||
|
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||||
|
prompt=new_req_data.prompt,
|
||||||
|
mm_inputs=new_req_data.mm_inputs,
|
||||||
|
mm_positions=new_req_data.mm_positions,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
generator=generator,
|
||||||
|
block_ids=new_req_data.block_ids,
|
||||||
|
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||||
|
output_token_ids=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||||
|
if self.model_config.uses_mrope:
|
||||||
|
image_grid_thw = []
|
||||||
|
video_grid_thw = []
|
||||||
|
for mm_input in self.requests[req_id].mm_inputs:
|
||||||
|
if mm_input.get("image_grid_thw") is not None:
|
||||||
|
image_grid_thw.extend(
|
||||||
|
mm_input["image_grid_thw"].tolist())
|
||||||
|
if mm_input.get("video_grid_thw") is not None:
|
||||||
|
video_grid_thw.extend(
|
||||||
|
mm_input["video_grid_thw"].tolist())
|
||||||
|
|
||||||
|
hf_config = self.model_config.hf_config
|
||||||
|
|
||||||
|
self.requests[req_id].mrope_positions, \
|
||||||
|
self.requests[req_id].mrope_position_delta = \
|
||||||
|
MRotaryEmbedding.get_input_positions_tensor(
|
||||||
|
self.requests[req_id].prompt_token_ids,
|
||||||
|
image_grid_thw=image_grid_thw,
|
||||||
|
video_grid_thw=video_grid_thw,
|
||||||
|
image_token_id=hf_config.image_token_id,
|
||||||
|
video_token_id=hf_config.video_token_id,
|
||||||
|
vision_start_token_id=hf_config.vision_start_token_id,
|
||||||
|
vision_end_token_id=hf_config.vision_end_token_id,
|
||||||
|
spatial_merge_size=hf_config.vision_config.
|
||||||
|
spatial_merge_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
req_ids_to_add.append(req_id)
|
||||||
|
|
||||||
|
# Update the cached states of the resumed requests.
|
||||||
|
for res_req_data in scheduler_output.scheduled_resumed_reqs:
|
||||||
|
req_id = res_req_data.req_id
|
||||||
|
req_state = self.requests[req_id]
|
||||||
|
|
||||||
|
req_state.block_ids = res_req_data.block_ids
|
||||||
|
req_state.num_computed_tokens = res_req_data.num_computed_tokens
|
||||||
|
req_ids_to_add.append(req_id)
|
||||||
|
|
||||||
|
# Add the new or resumed requests to the persistent batch.
|
||||||
|
# The smaller empty indices are filled first.
|
||||||
|
removed_req_indices = sorted(removed_req_indices, reverse=True)
|
||||||
|
for req_id in req_ids_to_add:
|
||||||
|
req_state = self.requests[req_id]
|
||||||
|
if removed_req_indices:
|
||||||
|
# Fill the empty index.
|
||||||
|
req_index = removed_req_indices.pop()
|
||||||
|
else:
|
||||||
|
# Append to the end.
|
||||||
|
req_index = None
|
||||||
|
self.input_batch.add_request(req_state, req_index)
|
||||||
|
|
||||||
|
# Condense the batched states if there are empty indices.
|
||||||
|
if removed_req_indices:
|
||||||
|
self.input_batch.condense(removed_req_indices)
|
||||||
|
|
||||||
def get_model(self) -> nn.Module:
|
def get_model(self) -> nn.Module:
|
||||||
assert self.model is not None
|
assert self.model is not None
|
||||||
return self.model
|
return self.model
|
||||||
|
|||||||
@ -15,7 +15,6 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.sampling_params import SamplingType
|
|
||||||
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
|
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
|
||||||
PallasMetadata)
|
PallasMetadata)
|
||||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
|
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
|
||||||
@ -40,25 +39,24 @@ _MAX_NUM_SAMPLES = 128
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PrefillInputData:
|
class PromptInputData:
|
||||||
|
|
||||||
request_ids: List
|
req_ids: List
|
||||||
prompt_lens: List
|
prompt_lens: List
|
||||||
token_ids: List
|
input_tokens: List
|
||||||
position_ids: List
|
input_positions: List
|
||||||
attn_metadata: List
|
attn_metadata: List
|
||||||
|
|
||||||
def zipped(self):
|
def zipped(self):
|
||||||
return zip(self.request_ids, self.prompt_lens, self.token_ids,
|
return zip(self.req_ids, self.prompt_lens, self.input_tokens,
|
||||||
self.position_ids, self.attn_metadata)
|
self.input_positions, self.attn_metadata)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DecodeInputData:
|
class DecodeInputData:
|
||||||
|
req_ids: List
|
||||||
num_decodes: int
|
input_tokens: Optional[torch.Tensor] = None
|
||||||
token_ids: Optional[torch.Tensor] = None
|
input_positions: Optional[torch.Tensor] = None
|
||||||
position_ids: Optional[torch.Tensor] = None
|
|
||||||
attn_metadata: Optional[PallasMetadata] = None
|
attn_metadata: Optional[PallasMetadata] = None
|
||||||
|
|
||||||
|
|
||||||
@ -88,158 +86,105 @@ class TPUModelRunner(ModelRunnerBase):
|
|||||||
self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []
|
self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []
|
||||||
|
|
||||||
# Used to initialize positions for the individual prefills
|
# Used to initialize positions for the individual prefills
|
||||||
self.prefill_positions = torch.tensor(range(self.max_model_len),
|
self.prefill_input_positions = torch.tensor(range(self.max_model_len),
|
||||||
device="cpu",
|
device="cpu",
|
||||||
dtype=torch.int32).reshape(
|
dtype=torch.int32).reshape(
|
||||||
1, -1)
|
1, -1)
|
||||||
|
|
||||||
# Used to indicate how many prefills there are for each scheduler
|
def _prepare_prompt_inputs(
|
||||||
# iteration
|
|
||||||
self.num_new_reqs: int = 0
|
|
||||||
|
|
||||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
|
||||||
# Remove stopped requests from the cached states.
|
|
||||||
# Keep the states of the pre-empted requests.
|
|
||||||
for req_id in scheduler_output.finished_req_ids:
|
|
||||||
self.requests.pop(req_id, None)
|
|
||||||
|
|
||||||
# Remove the requests from the persistent batch.
|
|
||||||
stopped_req_ids = set().union(
|
|
||||||
scheduler_output.preempted_req_ids,
|
|
||||||
scheduler_output.finished_req_ids,
|
|
||||||
)
|
|
||||||
removed_req_indices: List[int] = []
|
|
||||||
for req_id in stopped_req_ids:
|
|
||||||
req_index = self.input_batch.remove_request(req_id)
|
|
||||||
if req_index is not None:
|
|
||||||
removed_req_indices.append(req_index)
|
|
||||||
|
|
||||||
# Update the states of the running requests.
|
|
||||||
for req_data in scheduler_output.scheduled_running_reqs:
|
|
||||||
req_id = req_data.req_id
|
|
||||||
req_state = self.requests[req_id]
|
|
||||||
req_index = self.input_batch.req_id_to_index[req_id]
|
|
||||||
|
|
||||||
# Update the num_computed_tokens.
|
|
||||||
req_state.num_computed_tokens = req_data.num_computed_tokens
|
|
||||||
self.input_batch.num_computed_tokens_cpu[req_index] = (
|
|
||||||
req_data.num_computed_tokens)
|
|
||||||
|
|
||||||
# Update the block table.
|
|
||||||
num_new_blocks = len(req_data.new_block_ids)
|
|
||||||
if num_new_blocks == 0:
|
|
||||||
continue
|
|
||||||
start_index = len(req_state.block_ids)
|
|
||||||
req_state.block_ids.extend(req_data.new_block_ids)
|
|
||||||
self.input_batch.block_table.append_row(req_index, start_index,
|
|
||||||
req_data.new_block_ids)
|
|
||||||
|
|
||||||
req_ids_to_add: List[str] = []
|
|
||||||
# Add new requests to the cached states.
|
|
||||||
for new_req_data in scheduler_output.scheduled_new_reqs:
|
|
||||||
req_id = new_req_data.req_id
|
|
||||||
sampling_params = new_req_data.sampling_params
|
|
||||||
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
|
||||||
generator = torch.Generator(device=self.device)
|
|
||||||
generator.manual_seed(sampling_params.seed)
|
|
||||||
else:
|
|
||||||
generator = None
|
|
||||||
|
|
||||||
self.requests[req_id] = CachedRequestState(
|
|
||||||
req_id=req_id,
|
|
||||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
|
||||||
prompt=new_req_data.prompt,
|
|
||||||
mm_inputs=new_req_data.mm_inputs,
|
|
||||||
mm_positions=new_req_data.mm_positions,
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
generator=generator,
|
|
||||||
block_ids=new_req_data.block_ids,
|
|
||||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
|
||||||
output_token_ids=[],
|
|
||||||
)
|
|
||||||
req_ids_to_add.append(req_id)
|
|
||||||
|
|
||||||
# Update the cached states of the resumed requests.
|
|
||||||
for res_req_data in scheduler_output.scheduled_resumed_reqs:
|
|
||||||
req_id = res_req_data.req_id
|
|
||||||
req_state = self.requests[req_id]
|
|
||||||
|
|
||||||
req_state.block_ids = res_req_data.block_ids
|
|
||||||
req_state.num_computed_tokens = res_req_data.num_computed_tokens
|
|
||||||
req_ids_to_add.append(req_id)
|
|
||||||
|
|
||||||
# For TPU, we keep all of the decode requests before the
|
|
||||||
# prefill requests in the batch sequence.
|
|
||||||
# 1. First condense, so all decodes move to start
|
|
||||||
# 2. Then add new prefills to the end of the batch
|
|
||||||
removed_req_indices = sorted(removed_req_indices, reverse=True)
|
|
||||||
if removed_req_indices:
|
|
||||||
self.input_batch.condense(removed_req_indices)
|
|
||||||
|
|
||||||
for req_id in req_ids_to_add:
|
|
||||||
req_state = self.requests[req_id]
|
|
||||||
self.input_batch.add_request(req_state, None) # Append last
|
|
||||||
|
|
||||||
self.num_new_reqs = len(req_ids_to_add)
|
|
||||||
|
|
||||||
def _prepare_prefill_inputs(
|
|
||||||
self,
|
self,
|
||||||
num_scheduled_tokens: List[int],
|
scheduler_output: "SchedulerOutput",
|
||||||
) -> PrefillInputData:
|
) -> PromptInputData:
|
||||||
# Each prefill run separately with shape [1, padded_prompt_len].
|
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
# So we create lists that will be used in execute_model().
|
assert total_num_scheduled_tokens > 0
|
||||||
|
|
||||||
prefill_request_ids = []
|
|
||||||
prefill_prompt_lens = []
|
|
||||||
prefill_token_ids = []
|
|
||||||
prefill_position_ids = []
|
|
||||||
prefill_attn_metadata = []
|
|
||||||
|
|
||||||
# DECODES are the first num_decodes REQUESTS.
|
|
||||||
# PREFILLS are the next num_reqs - num_decodes REQUESTS.
|
|
||||||
num_reqs = self.input_batch.num_reqs
|
num_reqs = self.input_batch.num_reqs
|
||||||
num_decodes = num_reqs - self.num_new_reqs
|
assert num_reqs > 0
|
||||||
for idx in range(num_decodes, num_reqs):
|
|
||||||
req_id = self.input_batch.req_ids[idx]
|
|
||||||
prefill_request_ids.append(req_id)
|
|
||||||
|
|
||||||
prompt_len = num_scheduled_tokens[idx]
|
# OPTIMIZATION: Start copying the block table first.
|
||||||
prefill_prompt_lens.append(prompt_len)
|
# This way, we can overlap the copy with the following CPU operations.
|
||||||
|
self.input_batch.block_table.commit(num_reqs)
|
||||||
|
|
||||||
# STATIC SHAPE: prefills are padded to the next power of 2.
|
req_ids = []
|
||||||
|
prompt_lens = []
|
||||||
|
input_tokens_list = []
|
||||||
|
input_positions_list = []
|
||||||
|
attn_metadata_list = []
|
||||||
|
for req_id in self.input_batch.req_ids[:num_reqs]:
|
||||||
|
assert req_id is not None
|
||||||
|
req_index = self.input_batch.req_id_to_index[req_id]
|
||||||
|
req_state = self.requests[req_id]
|
||||||
|
|
||||||
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||||
|
req_id]
|
||||||
|
num_computed_tokens = req_state.num_computed_tokens
|
||||||
|
num_prompt_tokens = len(req_state.prompt_token_ids)
|
||||||
|
|
||||||
|
# Detect whether this is a prompt (can be full or chunked)
|
||||||
|
if num_computed_tokens >= num_prompt_tokens:
|
||||||
|
# This is a decode => Skip
|
||||||
|
continue
|
||||||
|
|
||||||
|
# This is a prompt
|
||||||
|
req_ids.append(req_id)
|
||||||
|
|
||||||
|
# Prompt len
|
||||||
|
prompt_len = num_scheduled_tokens
|
||||||
|
prompt_lens.append(prompt_len)
|
||||||
padded_prompt_len = _get_padded_prefill_len(prompt_len)
|
padded_prompt_len = _get_padded_prefill_len(prompt_len)
|
||||||
assert padded_prompt_len <= self.max_model_len
|
assert padded_prompt_len <= self.max_model_len
|
||||||
|
|
||||||
# TOKEN_IDS.
|
# Seq len
|
||||||
token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[
|
seq_len = num_computed_tokens + prompt_len
|
||||||
idx, :padded_prompt_len].reshape(1, -1))
|
padded_seq_len = num_computed_tokens + padded_prompt_len
|
||||||
token_ids[:, prompt_len:] = 0
|
|
||||||
prefill_token_ids.append(token_ids.to(self.device))
|
|
||||||
|
|
||||||
# POSITIONS.
|
# Input tokens
|
||||||
positions = self.prefill_positions[:, :padded_prompt_len].clone()
|
input_tokens = torch.from_numpy(self.input_batch.token_ids_cpu[
|
||||||
positions[:, prompt_len:] = 0
|
req_index, num_computed_tokens:padded_seq_len].reshape(1, -1))
|
||||||
prefill_position_ids.append(positions.to(self.device))
|
input_tokens[:, prompt_len:] = 0
|
||||||
|
input_tokens_list.append(input_tokens.to(self.device))
|
||||||
|
|
||||||
# SLOT_MAPPING.
|
# Input positions
|
||||||
# The "slot" is the "physical index" of a token in the KV cache.
|
input_positions = self.prefill_input_positions[:,
|
||||||
# Look up the block_idx in the block table (logical<>physical map)
|
num_computed_tokens:
|
||||||
# to compute this.
|
padded_seq_len].clone(
|
||||||
|
)
|
||||||
|
input_positions[:, prompt_len:] = 0
|
||||||
|
input_positions_list.append(input_positions.to(self.device))
|
||||||
|
|
||||||
|
# Slot mapping
|
||||||
block_table_cpu_tensor = \
|
block_table_cpu_tensor = \
|
||||||
self.input_batch.block_table.get_cpu_tensor()
|
self.input_batch.block_table.get_cpu_tensor()
|
||||||
|
block_numbers = block_table_cpu_tensor[req_index,
|
||||||
block_numbers = block_table_cpu_tensor[idx, positions //
|
input_positions //
|
||||||
self.block_size].reshape(
|
self.block_size].reshape(
|
||||||
1, -1)
|
1, -1)
|
||||||
|
|
||||||
block_offsets = positions % self.block_size
|
block_offsets = input_positions % self.block_size
|
||||||
slot_mapping = block_numbers * self.block_size + block_offsets
|
slot_mapping = block_numbers * self.block_size + block_offsets
|
||||||
# Set an out of range value for the padding tokens so that they
|
|
||||||
# are ignored when inserting into the KV cache.
|
|
||||||
slot_mapping[:, prompt_len:] = _PAD_SLOT_ID
|
slot_mapping[:, prompt_len:] = _PAD_SLOT_ID
|
||||||
slot_mapping = slot_mapping.long()
|
slot_mapping = slot_mapping.long()
|
||||||
|
|
||||||
prefill_attn_metadata.append(
|
# Block table
|
||||||
|
block_table = None
|
||||||
|
if num_computed_tokens > 0:
|
||||||
|
block_table = self.input_batch.block_table.get_device_tensor()
|
||||||
|
block_table = block_table[req_index].unsqueeze(0)
|
||||||
|
|
||||||
|
# Context len
|
||||||
|
context_len = 0
|
||||||
|
if num_computed_tokens > 0:
|
||||||
|
context_len = seq_len
|
||||||
|
context_lens = torch.tensor([context_len],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
|
||||||
|
# Effective query len
|
||||||
|
effective_query_lens = torch.tensor([prompt_len],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
|
||||||
|
# Attn metadata
|
||||||
|
attn_metadata_list.append(
|
||||||
PallasMetadata(
|
PallasMetadata(
|
||||||
num_prefills=1,
|
num_prefills=1,
|
||||||
num_prefill_tokens=0, # NOTE: This is not used.
|
num_prefill_tokens=0, # NOTE: This is not used.
|
||||||
@ -247,208 +192,200 @@ class TPUModelRunner(ModelRunnerBase):
|
|||||||
slot_mapping=slot_mapping.to(self.device),
|
slot_mapping=slot_mapping.to(self.device),
|
||||||
multi_modal_placeholder_index_maps=None,
|
multi_modal_placeholder_index_maps=None,
|
||||||
enable_kv_scales_calculation=True,
|
enable_kv_scales_calculation=True,
|
||||||
block_tables=None,
|
block_tables=block_table,
|
||||||
context_lens=None,
|
context_lens=context_lens.to(self.device),
|
||||||
effective_query_lens=None,
|
effective_query_lens=effective_query_lens.to(self.device),
|
||||||
))
|
))
|
||||||
|
|
||||||
return PrefillInputData(
|
return PromptInputData(
|
||||||
request_ids=prefill_request_ids,
|
req_ids=req_ids,
|
||||||
prompt_lens=prefill_prompt_lens,
|
prompt_lens=prompt_lens,
|
||||||
token_ids=prefill_token_ids,
|
input_tokens=input_tokens_list,
|
||||||
position_ids=prefill_position_ids,
|
input_positions=input_positions_list,
|
||||||
attn_metadata=prefill_attn_metadata,
|
attn_metadata=attn_metadata_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _prepare_decode_inputs(self) -> DecodeInputData:
|
def _prepare_decode_inputs(
|
||||||
# Decodes run as one single padded batch with shape [batch, 1]
|
self,
|
||||||
#
|
scheduler_output: "SchedulerOutput",
|
||||||
# We need to set _PAD_SLOT_ID for the padding tokens in the
|
) -> DecodeInputData:
|
||||||
# slot_mapping, such that the attention KV cache insertion
|
|
||||||
# logic knows to ignore those indices. Otherwise, the
|
|
||||||
# padding data can be dummy since we have a causal mask.
|
|
||||||
|
|
||||||
# DECODES are the first num_decodes REQUESTS.
|
|
||||||
# PREFILLS are the next num_reqs - num_decodes REQUESTS.
|
|
||||||
num_reqs = self.input_batch.num_reqs
|
|
||||||
num_decodes = num_reqs - self.num_new_reqs
|
|
||||||
|
|
||||||
if num_decodes == 0:
|
|
||||||
return DecodeInputData(num_decodes=0)
|
|
||||||
|
|
||||||
# PAD FOR STATIC SHAPES.
|
|
||||||
padded_batch_size = _get_padded_batch_size(num_decodes)
|
|
||||||
|
|
||||||
# POSITIONS. [batch, 1]
|
|
||||||
# We slice at the end, since we use the positions for gathering.
|
|
||||||
positions = torch.from_numpy(
|
|
||||||
self.input_batch.num_computed_tokens_cpu.reshape(-1, 1))
|
|
||||||
index = positions.to(torch.int64)
|
|
||||||
index[num_decodes:] = 0
|
|
||||||
positions = positions[:padded_batch_size]
|
|
||||||
positions[num_decodes:] = 0
|
|
||||||
|
|
||||||
# TOKEN_IDS. [batch, 1]
|
|
||||||
token_ids = torch.gather(
|
|
||||||
input=torch.from_numpy(self.input_batch.token_ids_cpu),
|
|
||||||
dim=1,
|
|
||||||
index=index,
|
|
||||||
)[:padded_batch_size].to(torch.int32)
|
|
||||||
token_ids[num_decodes:] = 0
|
|
||||||
|
|
||||||
# SLOT_MAPPING [batch, 1]
|
|
||||||
# The "slot" is the "physical index" of a token in the KV cache.
|
|
||||||
# Look up the block_idx in the block table (logical<>physical map)
|
|
||||||
# to compute this.
|
|
||||||
block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor()
|
|
||||||
block_number = torch.gather(input=block_table_cpu_tensor,
|
|
||||||
dim=1,
|
|
||||||
index=(index // self.block_size))
|
|
||||||
block_offsets = index % self.block_size
|
|
||||||
slot_mapping = block_number * self.block_size + block_offsets
|
|
||||||
# Set an out of range value for the padding tokens so that they
|
|
||||||
# are ignored when inserting into the KV cache.
|
|
||||||
slot_mapping[num_decodes:] = _PAD_SLOT_ID
|
|
||||||
slot_mapping = slot_mapping[:padded_batch_size]
|
|
||||||
slot_mapping = slot_mapping.long()
|
|
||||||
|
|
||||||
# BLOCK_TABLE [batch, max_num_blocks_per_req]
|
|
||||||
block_table = block_table_cpu_tensor[:padded_batch_size]
|
|
||||||
|
|
||||||
# CONTEXT_LENS [batch_size]
|
|
||||||
context_lens = (positions.reshape(-1) + 1)
|
|
||||||
context_lens[num_decodes:] = 0
|
|
||||||
|
|
||||||
# CPU<>TPU sync happens here.
|
|
||||||
return DecodeInputData(num_decodes=num_decodes,
|
|
||||||
token_ids=token_ids.to(self.device),
|
|
||||||
position_ids=positions.to(self.device),
|
|
||||||
attn_metadata=PallasMetadata(
|
|
||||||
num_prefills=0,
|
|
||||||
num_prefill_tokens=0,
|
|
||||||
num_decode_tokens=padded_batch_size,
|
|
||||||
slot_mapping=slot_mapping.to(self.device),
|
|
||||||
multi_modal_placeholder_index_maps=None,
|
|
||||||
enable_kv_scales_calculation=True,
|
|
||||||
block_tables=block_table.to(self.device),
|
|
||||||
context_lens=context_lens.to(self.device),
|
|
||||||
effective_query_lens=None,
|
|
||||||
))
|
|
||||||
|
|
||||||
def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
|
|
||||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
assert total_num_scheduled_tokens > 0
|
assert total_num_scheduled_tokens > 0
|
||||||
num_reqs = self.input_batch.num_reqs
|
num_reqs = self.input_batch.num_reqs
|
||||||
assert num_reqs > 0
|
assert num_reqs > 0
|
||||||
|
|
||||||
num_decodes = num_reqs - self.num_new_reqs
|
block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor()
|
||||||
|
|
||||||
# TODO: Resurrect
|
req_ids = []
|
||||||
# OPTIMIZATION: Start copying the block table first.
|
req_indices = []
|
||||||
# This way, we can overlap the copy with the following CPU operations.
|
input_tokens = []
|
||||||
# TODO: Verify this works with TPUs
|
input_positions = []
|
||||||
# self.input_batch.block_table.commit(num_reqs)
|
slot_mapping = []
|
||||||
|
context_lens = []
|
||||||
# Get the number of scheduled tokens for each request.
|
for req_id in self.input_batch.req_ids[:num_reqs]:
|
||||||
# TODO: The Python loop can be slow. Optimize.
|
|
||||||
num_scheduled_tokens = []
|
|
||||||
max_num_scheduled_tokens = 0
|
|
||||||
for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
|
|
||||||
assert req_id is not None
|
assert req_id is not None
|
||||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
req_index = self.input_batch.req_id_to_index[req_id]
|
||||||
num_scheduled_tokens.append(num_tokens)
|
req_state = self.requests[req_id]
|
||||||
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
|
|
||||||
num_tokens)
|
|
||||||
|
|
||||||
# NOTE: Assert that all the decodes are "decodes".
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||||
if idx < num_decodes:
|
req_id]
|
||||||
assert num_tokens == 1
|
num_computed_tokens = req_state.num_computed_tokens
|
||||||
|
num_prompt_tokens = len(req_state.prompt_token_ids)
|
||||||
|
|
||||||
assert max_num_scheduled_tokens > 0
|
# Detect whether this is a decode
|
||||||
|
if num_computed_tokens < num_prompt_tokens:
|
||||||
|
# This is a prompt => Skip
|
||||||
|
continue
|
||||||
|
|
||||||
return (
|
# This is a decode
|
||||||
self._prepare_prefill_inputs(num_scheduled_tokens),
|
req_ids.append(req_id)
|
||||||
self._prepare_decode_inputs(),
|
req_indices.append(req_index)
|
||||||
|
|
||||||
|
# Seq len
|
||||||
|
seq_len = num_computed_tokens + num_scheduled_tokens
|
||||||
|
|
||||||
|
# Sanity check decode
|
||||||
|
assert num_scheduled_tokens == 1
|
||||||
|
assert seq_len == req_state.num_tokens
|
||||||
|
|
||||||
|
# Input token
|
||||||
|
input_tokens.append([
|
||||||
|
self.input_batch.token_ids_cpu[req_index, num_computed_tokens]
|
||||||
|
])
|
||||||
|
|
||||||
|
# Position
|
||||||
|
input_positions.append([num_computed_tokens])
|
||||||
|
|
||||||
|
# Slot mapping
|
||||||
|
block_number = block_table_cpu_tensor[req_index,
|
||||||
|
num_computed_tokens //
|
||||||
|
self.block_size]
|
||||||
|
block_offset = num_computed_tokens % self.block_size
|
||||||
|
slot_id = block_number * self.block_size + block_offset
|
||||||
|
slot_mapping.append([slot_id])
|
||||||
|
|
||||||
|
# Context len
|
||||||
|
context_lens.append(seq_len)
|
||||||
|
|
||||||
|
# Compute padding
|
||||||
|
batch_size = len(input_tokens)
|
||||||
|
padded_batch_size = _get_padded_batch_size(batch_size)
|
||||||
|
num_padding = padded_batch_size - batch_size
|
||||||
|
|
||||||
|
# Add padding
|
||||||
|
input_tokens.extend([[0]] * num_padding)
|
||||||
|
input_positions.extend([[0]] * num_padding)
|
||||||
|
slot_mapping.extend([[_PAD_SLOT_ID]] * num_padding)
|
||||||
|
context_lens.extend([0] * num_padding)
|
||||||
|
req_indices.extend([0] * num_padding)
|
||||||
|
|
||||||
|
# Create tensors
|
||||||
|
input_tokens_tensor = torch.tensor(input_tokens,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
input_positions_tensor = torch.tensor(input_positions,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
slot_mapping_tensor = torch.tensor(slot_mapping,
|
||||||
|
dtype=torch.int64,
|
||||||
|
device="cpu")
|
||||||
|
context_lens_tensor = torch.tensor(context_lens,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
block_tables_tensor = block_table_cpu_tensor[req_indices]
|
||||||
|
|
||||||
|
# Attn metadata
|
||||||
|
attn_metadata = PallasMetadata(
|
||||||
|
num_prefills=0,
|
||||||
|
num_prefill_tokens=0,
|
||||||
|
num_decode_tokens=padded_batch_size,
|
||||||
|
slot_mapping=slot_mapping_tensor.to(self.device),
|
||||||
|
multi_modal_placeholder_index_maps=None,
|
||||||
|
enable_kv_scales_calculation=True,
|
||||||
|
block_tables=block_tables_tensor.to(self.device),
|
||||||
|
context_lens=context_lens_tensor.to(self.device),
|
||||||
|
effective_query_lens=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return DecodeInputData(
|
||||||
|
req_ids=req_ids,
|
||||||
|
input_tokens=input_tokens_tensor.to(self.device),
|
||||||
|
input_positions=input_positions_tensor.to(self.device),
|
||||||
|
attn_metadata=attn_metadata)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
) -> ModelRunnerOutput:
|
) -> ModelRunnerOutput:
|
||||||
self._update_states(scheduler_output)
|
# Update cached state
|
||||||
|
self.update_states(scheduler_output)
|
||||||
|
|
||||||
# Prepare the decoder inputs.
|
# Prepare inputs
|
||||||
prefill_data, decode_data = self._prepare_inputs(scheduler_output)
|
prompt_data = self._prepare_prompt_inputs(scheduler_output)
|
||||||
|
decode_data = self._prepare_decode_inputs(scheduler_output)
|
||||||
|
|
||||||
|
# Init
|
||||||
num_reqs = self.input_batch.num_reqs
|
num_reqs = self.input_batch.num_reqs
|
||||||
|
assert num_reqs > 0
|
||||||
|
sampled_token_ids_list = [0] * num_reqs
|
||||||
|
|
||||||
######################### DECODES #########################
|
# Run decodes (a single batch)
|
||||||
# Decodes run as one single batch with [padded_batch, 1]
|
if len(decode_data.req_ids) > 0:
|
||||||
sampled_token_ids_list = []
|
# Forward
|
||||||
if decode_data.num_decodes > 0:
|
|
||||||
# FORWARD.
|
|
||||||
with set_forward_context(decode_data.attn_metadata,
|
with set_forward_context(decode_data.attn_metadata,
|
||||||
self.vllm_config):
|
self.vllm_config):
|
||||||
assert self.model is not None
|
assert self.model is not None
|
||||||
selected_token_ids = self.model(decode_data.token_ids,
|
selected_token_ids = self.model(decode_data.input_tokens,
|
||||||
decode_data.position_ids,
|
decode_data.input_positions,
|
||||||
decode_data.attn_metadata,
|
decode_data.attn_metadata,
|
||||||
self.kv_caches)
|
self.kv_caches)
|
||||||
|
|
||||||
# NOTE: TPU<>CPU sync happens here.
|
# Transfer sampled tokens from TPU to CPU
|
||||||
# We need to call .cpu() first to avoid recompilation.
|
selected_token_ids_list = selected_token_ids.cpu().tolist()
|
||||||
token_ids = selected_token_ids.cpu()[:decode_data.num_decodes]
|
|
||||||
sampled_token_ids_list.extend(token_ids.tolist())
|
|
||||||
|
|
||||||
# UPDATE REQUEST STATE.
|
# Update cached state
|
||||||
for i, req_id in enumerate(
|
for i, req_id in enumerate(decode_data.req_ids):
|
||||||
self.input_batch.req_ids[:decode_data.num_decodes]):
|
req_index = self.input_batch.req_id_to_index[req_id]
|
||||||
assert req_id is not None
|
|
||||||
req_state = self.requests[req_id]
|
req_state = self.requests[req_id]
|
||||||
|
|
||||||
assert scheduler_output.num_scheduled_tokens[req_id] == 1
|
|
||||||
seq_len = (req_state.num_computed_tokens +
|
seq_len = (req_state.num_computed_tokens +
|
||||||
scheduler_output.num_scheduled_tokens[req_id])
|
scheduler_output.num_scheduled_tokens[req_id])
|
||||||
assert seq_len == req_state.num_tokens
|
|
||||||
|
|
||||||
token_id = sampled_token_ids_list[i]
|
token_id = selected_token_ids_list[i]
|
||||||
self.input_batch.token_ids_cpu[i, seq_len] = token_id
|
|
||||||
self.input_batch.num_tokens[i] += 1
|
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
|
||||||
|
self.input_batch.num_tokens[req_index] += 1
|
||||||
req_state.output_token_ids.append(token_id)
|
req_state.output_token_ids.append(token_id)
|
||||||
|
|
||||||
######################### PREFILLS #########################
|
sampled_token_ids_list[req_index] = token_id
|
||||||
# Prefills run separately with shape [1, padded_prefill_len],
|
|
||||||
# due to lack of variable length attention kernel so far.
|
|
||||||
for (req_id, prompt_len, token_ids, position_ids,
|
|
||||||
attn_metadata) in prefill_data.zipped():
|
|
||||||
assert req_id is not None
|
|
||||||
|
|
||||||
# FORWARD.
|
# Run each prompt
|
||||||
|
for (req_id, prompt_len, input_tokens, input_positions,
|
||||||
|
attn_metadata) in prompt_data.zipped():
|
||||||
|
assert req_id is not None
|
||||||
|
req_state = self.requests[req_id]
|
||||||
|
req_index = self.input_batch.req_id_to_index[req_id]
|
||||||
|
|
||||||
|
# Forward
|
||||||
with set_forward_context(attn_metadata, self.vllm_config):
|
with set_forward_context(attn_metadata, self.vllm_config):
|
||||||
assert self.model is not None
|
assert self.model is not None
|
||||||
selected_token_ids = self.model(token_ids, position_ids,
|
selected_token_ids = self.model(input_tokens, input_positions,
|
||||||
attn_metadata, self.kv_caches)
|
attn_metadata, self.kv_caches)
|
||||||
|
|
||||||
# NOTE: TPU<>CPU sync happens here.
|
|
||||||
# We need to call .cpu() first to avoid recompilation.
|
|
||||||
token_id = selected_token_ids.cpu()[prompt_len - 1].item()
|
|
||||||
sampled_token_ids_list.append(token_id)
|
|
||||||
req_state = self.requests[req_id]
|
|
||||||
|
|
||||||
assert req_state.num_computed_tokens == 0
|
|
||||||
seq_len = (req_state.num_computed_tokens +
|
seq_len = (req_state.num_computed_tokens +
|
||||||
scheduler_output.num_scheduled_tokens[req_id])
|
scheduler_output.num_scheduled_tokens[req_id])
|
||||||
assert seq_len == req_state.num_tokens
|
if seq_len >= len(req_state.prompt_token_ids):
|
||||||
assert prompt_len == seq_len
|
# Transfer sampled tokens from TPU to CPU
|
||||||
|
token_id = selected_token_ids.cpu()[prompt_len - 1].item()
|
||||||
|
sampled_token_ids_list[req_index] = token_id
|
||||||
|
|
||||||
# UPDATE REQUEST STATE.
|
# Update cached state
|
||||||
req_idx = self.input_batch.req_id_to_index[req_id]
|
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
|
||||||
self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id
|
self.input_batch.num_tokens[req_index] += 1
|
||||||
self.input_batch.num_tokens[req_idx] += 1
|
req_state.output_token_ids.append(token_id)
|
||||||
req_state.output_token_ids.append(token_id)
|
|
||||||
|
|
||||||
# num_reqs entries should be non-None
|
# Get req_ids
|
||||||
assert all(
|
assert all(
|
||||||
req_id is not None for req_id in
|
req_id is not None for req_id in
|
||||||
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
|
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user