Chunked prompt works!

Signed-off-by: Alexander Matveev <amatveev@redhat.com>

parent 0bddb6b9a5
commit 61bb55f3d5
@@ -213,11 +213,12 @@ class Scheduler:
                 num_new_tokens = self.block_size
                 computed_blocks.pop()

+                # TODO: Remove
                 # If chunked prefill is not enabled, then breakout of the loop
                 # when above budget.
-                if (not self.scheduler_config.chunked_prefill_enabled
-                        and num_new_tokens > token_budget):
-                    break
+                # if (not self.scheduler_config.chunked_prefill_enabled
+                #         and num_new_tokens > token_budget):
+                #     break

                 num_new_tokens = min(num_new_tokens, token_budget)
                 assert num_new_tokens > 0
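The scheduler hunk above stops breaking out of the scheduling loop when a prompt exceeds the token budget and instead always clamps the request's new tokens to the remaining budget, so an oversized prompt is consumed in chunks across steps. A minimal standalone sketch of that chunking rule (hypothetical helper, not part of the diff):

def chunk_tokens(num_prompt_tokens: int, num_computed_tokens: int,
                 token_budget: int) -> int:
    # Remaining prompt tokens for this request.
    num_new_tokens = num_prompt_tokens - num_computed_tokens
    # Clamp to the per-step budget; the leftover runs in later steps.
    num_new_tokens = min(num_new_tokens, token_budget)
    assert num_new_tokens > 0
    return num_new_tokens

For example, a 5000-token prompt under a 2048-token budget is prefilled over three steps (2048, 2048, 904) instead of stalling until the whole prompt fits.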
@@ -9,23 +9,18 @@ import torch.distributed
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import set_forward_context
-from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
-from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.utils import group_mm_inputs_by_modality
-from vllm.sampling_params import SamplingType
 from vllm.utils import DeviceMemoryProfiler, cdiv
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                    FlashAttentionMetadata)
-from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
-from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
 from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase

 if TYPE_CHECKING:
@@ -43,41 +38,9 @@ class GPUModelRunner(ModelRunnerBase):
     ):
         super().__init__(vllm_config, device)

-        # Persistent batch.
-        self.input_batch = InputBatch(
-            max_num_reqs=self.max_num_reqs,
-            max_model_len=self.max_model_len,
-            max_num_blocks_per_req=self.max_num_blocks_per_req,
-            device=self.device,
-            pin_memory=self.pin_memory,
-            vocab_size=self.model_config.get_vocab_size(),
-        )
-
-        # Request states.
-        self.requests: Dict[str, CachedRequestState] = {}
-
         # KV caches for forward pass
         self.kv_caches: List[torch.Tensor] = []

-        # Multi-modal data support
-        self.input_registry = INPUT_REGISTRY
-        self.mm_registry = MULTIMODAL_REGISTRY
-
-        # NOTE: Initialized input mapper is only used for processing dummy
-        # multimodal data into multimodal kwargs for GPU memory profiling.
-        self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
-        self.mm_input_mapper_profiling.use_cache = False
-
-        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
-            model_config=self.model_config,
-            scheduler_config=self.scheduler_config,
-        )
-        self.max_num_encoder_input_tokens = encoder_compute_budget
-        self.encoder_cache_size = encoder_cache_size
-
-        # req_id -> (input_id -> encoder_output)
-        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
-
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE
                                and not self.model_config.enforce_eager)
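Together with the model_runner_base.py hunks below, this hoists the persistent batch, request cache, profiling input mapper, and encoder cache into ModelRunnerBase, leaving only backend-specific state in the GPU runner. A rough, self-contained sketch of the resulting split (simplified stand-in types, not the real constructors):

from typing import Dict, List

class ModelRunnerBaseSketch:
    """Shared state this commit hoists into the base class (simplified)."""

    def __init__(self) -> None:
        self.requests: Dict[str, object] = {}    # req_id -> CachedRequestState
        self.encoder_cache: Dict[str, Dict[int, object]] = {}  # req_id -> outputs

class GPUModelRunnerSketch(ModelRunnerBaseSketch):
    """Backend-specific state stays in the subclass."""

    def __init__(self) -> None:
        super().__init__()
        self.kv_caches: List[object] = []        # GPU KV-cache tensors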
@@ -160,132 +123,6 @@ class GPUModelRunner(ModelRunnerBase):
                                          pin_memory=self.pin_memory)
         self.seq_lens_np = self.seq_lens_cpu.numpy()

-    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
-        # Remove stopped requests from the cached states.
-        # Keep the states of the pre-empted requests.
-        for req_id in scheduler_output.finished_req_ids:
-            self.requests.pop(req_id, None)
-            self.encoder_cache.pop(req_id, None)
-
-        # Free the cached encoder outputs.
-        for req_id, input_id in scheduler_output.free_encoder_input_ids:
-            encoder_outputs = self.encoder_cache.get(req_id)
-            if encoder_outputs is not None:
-                encoder_outputs.pop(input_id, None)
-                if not encoder_outputs:
-                    self.encoder_cache.pop(req_id, None)
-
-        # Remove the requests from the persistent batch.
-        stopped_req_ids = set().union(
-            scheduler_output.preempted_req_ids,
-            scheduler_output.finished_req_ids,
-        )
-        removed_req_indices: List[int] = []
-        for req_id in stopped_req_ids:
-            req_index = self.input_batch.remove_request(req_id)
-            if req_index is not None:
-                removed_req_indices.append(req_index)
-
-        # Update the states of the running requests.
-        for req_data in scheduler_output.scheduled_running_reqs:
-            req_id = req_data.req_id
-            req_state = self.requests[req_id]
-            req_index = self.input_batch.req_id_to_index[req_id]
-
-            # Update the num_computed_tokens.
-            req_state.num_computed_tokens = req_data.num_computed_tokens
-            self.input_batch.num_computed_tokens_cpu[req_index] = (
-                req_data.num_computed_tokens)
-
-            # Update the block table.
-            num_new_blocks = len(req_data.new_block_ids)
-            if num_new_blocks == 0:
-                continue
-            start_index = len(req_state.block_ids)
-            req_state.block_ids.extend(req_data.new_block_ids)
-            self.input_batch.block_table.append_row(req_index, start_index,
-                                                    req_data.new_block_ids)
-
-        req_ids_to_add: List[str] = []
-        # Add new requests to the cached states.
-        for new_req_data in scheduler_output.scheduled_new_reqs:
-            req_id = new_req_data.req_id
-            sampling_params = new_req_data.sampling_params
-            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-                generator = torch.Generator(device=self.device)
-                generator.manual_seed(sampling_params.seed)
-            else:
-                generator = None
-
-            self.requests[req_id] = CachedRequestState(
-                req_id=req_id,
-                prompt_token_ids=new_req_data.prompt_token_ids,
-                prompt=new_req_data.prompt,
-                mm_inputs=new_req_data.mm_inputs,
-                mm_positions=new_req_data.mm_positions,
-                sampling_params=sampling_params,
-                generator=generator,
-                block_ids=new_req_data.block_ids,
-                num_computed_tokens=new_req_data.num_computed_tokens,
-                output_token_ids=[],
-            )
-
-            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
-            if self.model_config.uses_mrope:
-                image_grid_thw = []
-                video_grid_thw = []
-                for mm_input in self.requests[req_id].mm_inputs:
-                    if mm_input.get("image_grid_thw") is not None:
-                        image_grid_thw.extend(
-                            mm_input["image_grid_thw"].tolist())
-                    if mm_input.get("video_grid_thw") is not None:
-                        video_grid_thw.extend(
-                            mm_input["video_grid_thw"].tolist())
-
-                hf_config = self.model_config.hf_config
-
-                self.requests[req_id].mrope_positions, \
-                    self.requests[req_id].mrope_position_delta = \
-                    MRotaryEmbedding.get_input_positions_tensor(
-                        self.requests[req_id].prompt_token_ids,
-                        image_grid_thw=image_grid_thw,
-                        video_grid_thw=video_grid_thw,
-                        image_token_id=hf_config.image_token_id,
-                        video_token_id=hf_config.video_token_id,
-                        vision_start_token_id=hf_config.vision_start_token_id,
-                        vision_end_token_id=hf_config.vision_end_token_id,
-                        spatial_merge_size=hf_config.vision_config.
-                        spatial_merge_size,
-                    )
-
-            req_ids_to_add.append(req_id)
-
-        # Update the cached states of the resumed requests.
-        for res_req_data in scheduler_output.scheduled_resumed_reqs:
-            req_id = res_req_data.req_id
-            req_state = self.requests[req_id]
-
-            req_state.block_ids = res_req_data.block_ids
-            req_state.num_computed_tokens = res_req_data.num_computed_tokens
-            req_ids_to_add.append(req_id)
-
-        # Add the new or resumed requests to the persistent batch.
-        # The smaller empty indices are filled first.
-        removed_req_indices = sorted(removed_req_indices, reverse=True)
-        for req_id in req_ids_to_add:
-            req_state = self.requests[req_id]
-            if removed_req_indices:
-                # Fill the empty index.
-                req_index = removed_req_indices.pop()
-            else:
-                # Append to the end.
-                req_index = None
-            self.input_batch.add_request(req_state, req_index)
-
-        # Condense the batched states if there are empty indices.
-        if removed_req_indices:
-            self.input_batch.condense(removed_req_indices)
-
     def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
@@ -665,7 +502,7 @@ class GPUModelRunner(ModelRunnerBase):
     ) -> ModelRunnerOutput:
         assert self.model is not None

-        self._update_states(scheduler_output)
+        self.update_states(scheduler_output)

         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
@@ -1,5 +1,5 @@
 import enum
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional

 import torch
 import torch.distributed
@@ -8,11 +8,18 @@ import torch.nn as nn
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig
+from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
+from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.sampling_params import SamplingType
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
+from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
+from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

 if TYPE_CHECKING:
     from vllm.v1.core.scheduler import SchedulerOutput
@@ -75,6 +82,164 @@ class ModelRunnerBase:

         self.model: Optional[nn.Module] = None

+        # Persistent batch.
+        self.input_batch = InputBatch(
+            max_num_reqs=self.max_num_reqs,
+            max_model_len=self.max_model_len,
+            max_num_blocks_per_req=self.max_num_blocks_per_req,
+            device=self.device,
+            pin_memory=self.pin_memory,
+            vocab_size=self.model_config.get_vocab_size(),
+        )
+
+        # Request states.
+        self.requests: Dict[str, CachedRequestState] = {}
+
+        # Multi-modal data support
+        self.input_registry = INPUT_REGISTRY
+        self.mm_registry = MULTIMODAL_REGISTRY
+
+        # NOTE: Initialized input mapper is only used for processing dummy
+        # multimodal data into multimodal kwargs for GPU memory profiling.
+        self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
+        self.mm_input_mapper_profiling.use_cache = False
+
+        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
+            model_config=self.model_config,
+            scheduler_config=self.scheduler_config,
+        )
+        self.max_num_encoder_input_tokens = encoder_compute_budget
+        self.encoder_cache_size = encoder_cache_size
+
+        # req_id -> (input_id -> encoder_output)
+        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
+
+    def update_states(self, scheduler_output: "SchedulerOutput") -> None:
+        # Remove stopped requests from the cached states.
+        # Keep the states of the pre-empted requests.
+        for req_id in scheduler_output.finished_req_ids:
+            self.requests.pop(req_id, None)
+            self.encoder_cache.pop(req_id, None)
+
+        # Free the cached encoder outputs.
+        for req_id, input_id in scheduler_output.free_encoder_input_ids:
+            encoder_outputs = self.encoder_cache.get(req_id)
+            if encoder_outputs is not None:
+                encoder_outputs.pop(input_id, None)
+                if not encoder_outputs:
+                    self.encoder_cache.pop(req_id, None)
+
+        # Remove the requests from the persistent batch.
+        stopped_req_ids = set().union(
+            scheduler_output.preempted_req_ids,
+            scheduler_output.finished_req_ids,
+        )
+        removed_req_indices: List[int] = []
+        for req_id in stopped_req_ids:
+            req_index = self.input_batch.remove_request(req_id)
+            if req_index is not None:
+                removed_req_indices.append(req_index)
+
+        # Update the states of the running requests.
+        for req_data in scheduler_output.scheduled_running_reqs:
+            req_id = req_data.req_id
+            req_state = self.requests[req_id]
+            req_index = self.input_batch.req_id_to_index[req_id]
+
+            # Update the num_computed_tokens.
+            req_state.num_computed_tokens = req_data.num_computed_tokens
+            self.input_batch.num_computed_tokens_cpu[req_index] = (
+                req_data.num_computed_tokens)
+
+            # Update the block table.
+            num_new_blocks = len(req_data.new_block_ids)
+            if num_new_blocks == 0:
+                continue
+            start_index = len(req_state.block_ids)
+            req_state.block_ids.extend(req_data.new_block_ids)
+            self.input_batch.block_table.append_row(req_index, start_index,
+                                                    req_data.new_block_ids)
+
+        req_ids_to_add: List[str] = []
+        # Add new requests to the cached states.
+        for new_req_data in scheduler_output.scheduled_new_reqs:
+            req_id = new_req_data.req_id
+            sampling_params = new_req_data.sampling_params
+            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
+                generator = torch.Generator(device=self.device)
+                generator.manual_seed(sampling_params.seed)
+            else:
+                generator = None
+
+            self.requests[req_id] = CachedRequestState(
+                req_id=req_id,
+                prompt_token_ids=new_req_data.prompt_token_ids,
+                prompt=new_req_data.prompt,
+                mm_inputs=new_req_data.mm_inputs,
+                mm_positions=new_req_data.mm_positions,
+                sampling_params=sampling_params,
+                generator=generator,
+                block_ids=new_req_data.block_ids,
+                num_computed_tokens=new_req_data.num_computed_tokens,
+                output_token_ids=[],
+            )
+
+            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
+            if self.model_config.uses_mrope:
+                image_grid_thw = []
+                video_grid_thw = []
+                for mm_input in self.requests[req_id].mm_inputs:
+                    if mm_input.get("image_grid_thw") is not None:
+                        image_grid_thw.extend(
+                            mm_input["image_grid_thw"].tolist())
+                    if mm_input.get("video_grid_thw") is not None:
+                        video_grid_thw.extend(
+                            mm_input["video_grid_thw"].tolist())
+
+                hf_config = self.model_config.hf_config
+
+                self.requests[req_id].mrope_positions, \
+                    self.requests[req_id].mrope_position_delta = \
+                    MRotaryEmbedding.get_input_positions_tensor(
+                        self.requests[req_id].prompt_token_ids,
+                        image_grid_thw=image_grid_thw,
+                        video_grid_thw=video_grid_thw,
+                        image_token_id=hf_config.image_token_id,
+                        video_token_id=hf_config.video_token_id,
+                        vision_start_token_id=hf_config.vision_start_token_id,
+                        vision_end_token_id=hf_config.vision_end_token_id,
+                        spatial_merge_size=hf_config.vision_config.
+                        spatial_merge_size,
+                    )
+
+            req_ids_to_add.append(req_id)
+
+        # Update the cached states of the resumed requests.
+        for res_req_data in scheduler_output.scheduled_resumed_reqs:
+            req_id = res_req_data.req_id
+            req_state = self.requests[req_id]
+
+            req_state.block_ids = res_req_data.block_ids
+            req_state.num_computed_tokens = res_req_data.num_computed_tokens
+            req_ids_to_add.append(req_id)
+
+        # Add the new or resumed requests to the persistent batch.
+        # The smaller empty indices are filled first.
+        removed_req_indices = sorted(removed_req_indices, reverse=True)
+        for req_id in req_ids_to_add:
+            req_state = self.requests[req_id]
+            if removed_req_indices:
+                # Fill the empty index.
+                req_index = removed_req_indices.pop()
+            else:
+                # Append to the end.
+                req_index = None
+            self.input_batch.add_request(req_state, req_index)
+
+        # Condense the batched states if there are empty indices.
+        if removed_req_indices:
+            self.input_batch.condense(removed_req_indices)

     def get_model(self) -> nn.Module:
         assert self.model is not None
         return self.model
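update_states() reuses batch slots: indices freed by finished or preempted requests are filled first (smallest index first, by popping a reverse-sorted list), and the batch is condensed only if holes remain afterwards. A self-contained sketch of just that index bookkeeping (hypothetical helper, list-based stand-in for InputBatch):

from typing import List, Optional

def assign_slots(removed: List[int], num_new: int) -> List[Optional[int]]:
    """Fill freed slots first (smallest index first), then append (None)."""
    removed = sorted(removed, reverse=True)   # pop() yields the smallest index
    slots: List[Optional[int]] = []
    for _ in range(num_new):
        slots.append(removed.pop() if removed else None)
    return slots

# Example: slots 1 and 3 were freed and three requests arrive:
# assign_slots([3, 1], 3) -> [1, 3, None]  (None means "append to the end")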
@@ -15,7 +15,6 @@ from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
-from vllm.sampling_params import SamplingType
 from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
                                                PallasMetadata)
 from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
@@ -40,25 +39,24 @@ _MAX_NUM_SAMPLES = 128


 @dataclass
-class PrefillInputData:
+class PromptInputData:

-    request_ids: List
+    req_ids: List
     prompt_lens: List
-    token_ids: List
-    position_ids: List
+    input_tokens: List
+    input_positions: List
     attn_metadata: List

     def zipped(self):
-        return zip(self.request_ids, self.prompt_lens, self.token_ids,
-                   self.position_ids, self.attn_metadata)
+        return zip(self.req_ids, self.prompt_lens, self.input_tokens,
+                   self.input_positions, self.attn_metadata)


 @dataclass
 class DecodeInputData:

-    num_decodes: int
-    token_ids: Optional[torch.Tensor] = None
-    position_ids: Optional[torch.Tensor] = None
+    req_ids: List
+    input_tokens: Optional[torch.Tensor] = None
+    input_positions: Optional[torch.Tensor] = None
     attn_metadata: Optional[PallasMetadata] = None
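PromptInputData keeps per-prompt lists because each prompt runs as its own [1, padded_len] batch on TPU; zipped() lets execute_model() walk the lists in lockstep. A small usage sketch with dummy values (assumes the PromptInputData dataclass from this diff is in scope):

data = PromptInputData(
    req_ids=["req-0", "req-1"],
    prompt_lens=[7, 13],
    input_tokens=[None, None],      # per-prompt [1, padded_len] tensors
    input_positions=[None, None],   # per-prompt position tensors
    attn_metadata=[None, None],     # per-prompt PallasMetadata
)
for req_id, prompt_len, tokens, positions, meta in data.zipped():
    pass  # one forward pass per prompt, shape [1, padded_prompt_len]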
@@ -88,158 +86,105 @@ class TPUModelRunner(ModelRunnerBase):
         self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []

         # Used to initialize positions for the individual prefills
-        self.prefill_positions = torch.tensor(range(self.max_model_len),
-                                              device="cpu",
-                                              dtype=torch.int32).reshape(
-                                                  1, -1)
+        self.prefill_input_positions = torch.tensor(range(self.max_model_len),
+                                                    device="cpu",
+                                                    dtype=torch.int32).reshape(
+                                                        1, -1)

-        # Used to indicate how many prefills there are for each scheduler
-        # iteration
-        self.num_new_reqs: int = 0
-
-    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
-        # Remove stopped requests from the cached states.
-        # Keep the states of the pre-empted requests.
-        for req_id in scheduler_output.finished_req_ids:
-            self.requests.pop(req_id, None)
-
-        # Remove the requests from the persistent batch.
-        stopped_req_ids = set().union(
-            scheduler_output.preempted_req_ids,
-            scheduler_output.finished_req_ids,
-        )
-        removed_req_indices: List[int] = []
-        for req_id in stopped_req_ids:
-            req_index = self.input_batch.remove_request(req_id)
-            if req_index is not None:
-                removed_req_indices.append(req_index)
-
-        # Update the states of the running requests.
-        for req_data in scheduler_output.scheduled_running_reqs:
-            req_id = req_data.req_id
-            req_state = self.requests[req_id]
-            req_index = self.input_batch.req_id_to_index[req_id]
-
-            # Update the num_computed_tokens.
-            req_state.num_computed_tokens = req_data.num_computed_tokens
-            self.input_batch.num_computed_tokens_cpu[req_index] = (
-                req_data.num_computed_tokens)
-
-            # Update the block table.
-            num_new_blocks = len(req_data.new_block_ids)
-            if num_new_blocks == 0:
-                continue
-            start_index = len(req_state.block_ids)
-            req_state.block_ids.extend(req_data.new_block_ids)
-            self.input_batch.block_table.append_row(req_index, start_index,
-                                                    req_data.new_block_ids)
-
-        req_ids_to_add: List[str] = []
-        # Add new requests to the cached states.
-        for new_req_data in scheduler_output.scheduled_new_reqs:
-            req_id = new_req_data.req_id
-            sampling_params = new_req_data.sampling_params
-            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-                generator = torch.Generator(device=self.device)
-                generator.manual_seed(sampling_params.seed)
-            else:
-                generator = None
-
-            self.requests[req_id] = CachedRequestState(
-                req_id=req_id,
-                prompt_token_ids=new_req_data.prompt_token_ids,
-                prompt=new_req_data.prompt,
-                mm_inputs=new_req_data.mm_inputs,
-                mm_positions=new_req_data.mm_positions,
-                sampling_params=sampling_params,
-                generator=generator,
-                block_ids=new_req_data.block_ids,
-                num_computed_tokens=new_req_data.num_computed_tokens,
-                output_token_ids=[],
-            )
-            req_ids_to_add.append(req_id)
-
-        # Update the cached states of the resumed requests.
-        for res_req_data in scheduler_output.scheduled_resumed_reqs:
-            req_id = res_req_data.req_id
-            req_state = self.requests[req_id]
-
-            req_state.block_ids = res_req_data.block_ids
-            req_state.num_computed_tokens = res_req_data.num_computed_tokens
-            req_ids_to_add.append(req_id)
-
-        # For TPU, we keep all of the decode requests before the
-        # prefill requests in the batch sequence.
-        # 1. First condense, so all decodes move to start
-        # 2. Then add new prefills to the end of the batch
-        removed_req_indices = sorted(removed_req_indices, reverse=True)
-        if removed_req_indices:
-            self.input_batch.condense(removed_req_indices)
-
-        for req_id in req_ids_to_add:
-            req_state = self.requests[req_id]
-            self.input_batch.add_request(req_state, None)  # Append last
-
-        self.num_new_reqs = len(req_ids_to_add)
-
-    def _prepare_prefill_inputs(
+    def _prepare_prompt_inputs(
         self,
-        num_scheduled_tokens: List[int],
-    ) -> PrefillInputData:
-        # Each prefill run separately with shape [1, padded_prompt_len].
-        # So we create lists that will be used in execute_model().
-
-        prefill_request_ids = []
-        prefill_prompt_lens = []
-        prefill_token_ids = []
-        prefill_position_ids = []
-        prefill_attn_metadata = []
-
-        # DECODES are the first num_decodes REQUESTS.
-        # PREFILLS are the next num_reqs - num_decodes REQUESTS.
+        scheduler_output: "SchedulerOutput",
+    ) -> PromptInputData:
+        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
-        num_decodes = num_reqs - self.num_new_reqs
-        for idx in range(num_decodes, num_reqs):
-            req_id = self.input_batch.req_ids[idx]
-            prefill_request_ids.append(req_id)
+        assert num_reqs > 0

-            prompt_len = num_scheduled_tokens[idx]
-            prefill_prompt_lens.append(prompt_len)
+        # OPTIMIZATION: Start copying the block table first.
+        # This way, we can overlap the copy with the following CPU operations.
+        self.input_batch.block_table.commit(num_reqs)

-            # STATIC SHAPE: prefills are padded to the next power of 2.
+        req_ids = []
+        prompt_lens = []
+        input_tokens_list = []
+        input_positions_list = []
+        attn_metadata_list = []
+        for req_id in self.input_batch.req_ids[:num_reqs]:
+            assert req_id is not None
+            req_index = self.input_batch.req_id_to_index[req_id]
+            req_state = self.requests[req_id]
+
+            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
+                req_id]
+            num_computed_tokens = req_state.num_computed_tokens
+            num_prompt_tokens = len(req_state.prompt_token_ids)
+
+            # Detect whether this is a prompt (can be full or chunked)
+            if num_computed_tokens >= num_prompt_tokens:
+                # This is a decode => Skip
+                continue
+
+            # This is a prompt
+            req_ids.append(req_id)
+
+            # Prompt len
+            prompt_len = num_scheduled_tokens
+            prompt_lens.append(prompt_len)
             padded_prompt_len = _get_padded_prefill_len(prompt_len)
             assert padded_prompt_len <= self.max_model_len

-            # TOKEN_IDS.
-            token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[
-                idx, :padded_prompt_len].reshape(1, -1))
-            token_ids[:, prompt_len:] = 0
-            prefill_token_ids.append(token_ids.to(self.device))
+            # Seq len
+            seq_len = num_computed_tokens + prompt_len
+            padded_seq_len = num_computed_tokens + padded_prompt_len

-            # POSITIONS.
-            positions = self.prefill_positions[:, :padded_prompt_len].clone()
-            positions[:, prompt_len:] = 0
-            prefill_position_ids.append(positions.to(self.device))
+            # Input tokens
+            input_tokens = torch.from_numpy(self.input_batch.token_ids_cpu[
+                req_index, num_computed_tokens:padded_seq_len].reshape(1, -1))
+            input_tokens[:, prompt_len:] = 0
+            input_tokens_list.append(input_tokens.to(self.device))

-            # SLOT_MAPPING.
-            # The "slot" is the "physical index" of a token in the KV cache.
-            # Look up the block_idx in the block table (logical<>physical map)
-            # to compute this.
+            # Input positions
+            input_positions = self.prefill_input_positions[:,
+                                                           num_computed_tokens:
+                                                           padded_seq_len].clone(
+                                                           )
+            input_positions[:, prompt_len:] = 0
+            input_positions_list.append(input_positions.to(self.device))
+
+            # Slot mapping
             block_table_cpu_tensor = \
                 self.input_batch.block_table.get_cpu_tensor()
-            block_numbers = block_table_cpu_tensor[idx, positions //
+            block_numbers = block_table_cpu_tensor[req_index,
+                                                   input_positions //
                                                    self.block_size].reshape(
                                                        1, -1)
-            block_offsets = positions % self.block_size
+            block_offsets = input_positions % self.block_size
             slot_mapping = block_numbers * self.block_size + block_offsets
             # Set an out of range value for the padding tokens so that they
             # are ignored when inserting into the KV cache.
             slot_mapping[:, prompt_len:] = _PAD_SLOT_ID
             slot_mapping = slot_mapping.long()

-            prefill_attn_metadata.append(
+            # Block table
+            block_table = None
+            if num_computed_tokens > 0:
+                block_table = self.input_batch.block_table.get_device_tensor()
+                block_table = block_table[req_index].unsqueeze(0)
+
+            # Context len
+            context_len = 0
+            if num_computed_tokens > 0:
+                context_len = seq_len
+            context_lens = torch.tensor([context_len],
+                                        dtype=torch.int32,
+                                        device="cpu")
+
+            # Effective query len
+            effective_query_lens = torch.tensor([prompt_len],
+                                                dtype=torch.int32,
+                                                device="cpu")
+
+            # Attn metadata
+            attn_metadata_list.append(
                 PallasMetadata(
                     num_prefills=1,
                     num_prefill_tokens=0,  # NOTE: This is not used.
@@ -247,208 +192,200 @@ class TPUModelRunner(ModelRunnerBase):
                     slot_mapping=slot_mapping.to(self.device),
                     multi_modal_placeholder_index_maps=None,
                     enable_kv_scales_calculation=True,
-                    block_tables=None,
-                    context_lens=None,
-                    effective_query_lens=None,
+                    block_tables=block_table,
+                    context_lens=context_lens.to(self.device),
+                    effective_query_lens=effective_query_lens.to(self.device),
                 ))

-        return PrefillInputData(
-            request_ids=prefill_request_ids,
-            prompt_lens=prefill_prompt_lens,
-            token_ids=prefill_token_ids,
-            position_ids=prefill_position_ids,
-            attn_metadata=prefill_attn_metadata,
+        return PromptInputData(
+            req_ids=req_ids,
+            prompt_lens=prompt_lens,
+            input_tokens=input_tokens_list,
+            input_positions=input_positions_list,
+            attn_metadata=attn_metadata_list,
         )
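Two invariants drive the prompt path above: static shapes (each prompt chunk is padded to the next power of two, so XLA recompiles rarely) and slot mapping (slot = block_number * block_size + block_offset, with padding positions pointed at _PAD_SLOT_ID so they never land in the KV cache). A self-contained sketch of both computations, using plain ints instead of tensors; padded_prefill_len here is only an assumption about what _get_padded_prefill_len does:

def padded_prefill_len(n: int, minimum: int = 16) -> int:
    """Next power of two >= n (assumed behavior of _get_padded_prefill_len)."""
    p = minimum
    while p < n:
        p *= 2
    return p

def slot_for_position(block_table: list, position: int, block_size: int) -> int:
    """Physical KV-cache slot for a token at a given logical position."""
    block_number = block_table[position // block_size]
    block_offset = position % block_size
    return block_number * block_size + block_offset

# e.g. padded_prefill_len(900) == 1024; positions past the real prompt_len
# are assigned _PAD_SLOT_ID instead, so padded tokens are ignored on insert.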

-    def _prepare_decode_inputs(self) -> DecodeInputData:
-        # Decodes run as one single padded batch with shape [batch, 1]
-        #
-        # We need to set _PAD_SLOT_ID for the padding tokens in the
-        # slot_mapping, such that the attention KV cache insertion
-        # logic knows to ignore those indices. Otherwise, the
-        # padding data can be dummy since we have a causal mask.
-
-        # DECODES are the first num_decodes REQUESTS.
-        # PREFILLS are the next num_reqs - num_decodes REQUESTS.
-        num_reqs = self.input_batch.num_reqs
-        num_decodes = num_reqs - self.num_new_reqs
-
-        if num_decodes == 0:
-            return DecodeInputData(num_decodes=0)
-
-        # PAD FOR STATIC SHAPES.
-        padded_batch_size = _get_padded_batch_size(num_decodes)
-
-        # POSITIONS. [batch, 1]
-        # We slice at the end, since we use the positions for gathering.
-        positions = torch.from_numpy(
-            self.input_batch.num_computed_tokens_cpu.reshape(-1, 1))
-        index = positions.to(torch.int64)
-        index[num_decodes:] = 0
-        positions = positions[:padded_batch_size]
-        positions[num_decodes:] = 0
-
-        # TOKEN_IDS. [batch, 1]
-        token_ids = torch.gather(
-            input=torch.from_numpy(self.input_batch.token_ids_cpu),
-            dim=1,
-            index=index,
-        )[:padded_batch_size].to(torch.int32)
-        token_ids[num_decodes:] = 0
-
-        # SLOT_MAPPING [batch, 1]
-        # The "slot" is the "physical index" of a token in the KV cache.
-        # Look up the block_idx in the block table (logical<>physical map)
-        # to compute this.
-        block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor()
-        block_number = torch.gather(input=block_table_cpu_tensor,
-                                    dim=1,
-                                    index=(index // self.block_size))
-        block_offsets = index % self.block_size
-        slot_mapping = block_number * self.block_size + block_offsets
-        # Set an out of range value for the padding tokens so that they
-        # are ignored when inserting into the KV cache.
-        slot_mapping[num_decodes:] = _PAD_SLOT_ID
-        slot_mapping = slot_mapping[:padded_batch_size]
-        slot_mapping = slot_mapping.long()
-
-        # BLOCK_TABLE [batch, max_num_blocks_per_req]
-        block_table = block_table_cpu_tensor[:padded_batch_size]
-
-        # CONTEXT_LENS [batch_size]
-        context_lens = (positions.reshape(-1) + 1)
-        context_lens[num_decodes:] = 0
-
-        # CPU<>TPU sync happens here.
-        return DecodeInputData(num_decodes=num_decodes,
-                               token_ids=token_ids.to(self.device),
-                               position_ids=positions.to(self.device),
-                               attn_metadata=PallasMetadata(
-                                   num_prefills=0,
-                                   num_prefill_tokens=0,
-                                   num_decode_tokens=padded_batch_size,
-                                   slot_mapping=slot_mapping.to(self.device),
-                                   multi_modal_placeholder_index_maps=None,
-                                   enable_kv_scales_calculation=True,
-                                   block_tables=block_table.to(self.device),
-                                   context_lens=context_lens.to(self.device),
-                                   effective_query_lens=None,
-                               ))
-
-    def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
+    def _prepare_decode_inputs(
+        self,
+        scheduler_output: "SchedulerOutput",
+    ) -> DecodeInputData:
+        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
+        assert num_reqs > 0

-        num_decodes = num_reqs - self.num_new_reqs
+        block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor()

+        # TODO: Resurrect
+        # OPTIMIZATION: Start copying the block table first.
+        # This way, we can overlap the copy with the following CPU operations.
+        # TODO: Verify this works with TPUs
+        # self.input_batch.block_table.commit(num_reqs)
+
-        # Get the number of scheduled tokens for each request.
-        # TODO: The Python loop can be slow. Optimize.
-        num_scheduled_tokens = []
-        max_num_scheduled_tokens = 0
-        for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
+        req_ids = []
+        req_indices = []
+        input_tokens = []
+        input_positions = []
+        slot_mapping = []
+        context_lens = []
+        for req_id in self.input_batch.req_ids[:num_reqs]:
             assert req_id is not None
-            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-            num_scheduled_tokens.append(num_tokens)
-            max_num_scheduled_tokens = max(max_num_scheduled_tokens,
-                                           num_tokens)
+            req_index = self.input_batch.req_id_to_index[req_id]
+            req_state = self.requests[req_id]

-            # NOTE: Assert that all the decodes are "decodes".
-            if idx < num_decodes:
-                assert num_tokens == 1
+            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
+                req_id]
+            num_computed_tokens = req_state.num_computed_tokens
+            num_prompt_tokens = len(req_state.prompt_token_ids)

-        assert max_num_scheduled_tokens > 0
+            # Detect whether this is a decode
+            if num_computed_tokens < num_prompt_tokens:
+                # This is a prompt => Skip
+                continue

-        return (
-            self._prepare_prefill_inputs(num_scheduled_tokens),
-            self._prepare_decode_inputs(),
-        )
+            # This is a decode
+            req_ids.append(req_id)
+            req_indices.append(req_index)
+
+            # Seq len
+            seq_len = num_computed_tokens + num_scheduled_tokens
+
+            # Sanity check decode
+            assert num_scheduled_tokens == 1
+            assert seq_len == req_state.num_tokens
+
+            # Input token
+            input_tokens.append([
+                self.input_batch.token_ids_cpu[req_index, num_computed_tokens]
+            ])
+
+            # Position
+            input_positions.append([num_computed_tokens])
+
+            # Slot mapping
+            block_number = block_table_cpu_tensor[req_index,
+                                                  num_computed_tokens //
+                                                  self.block_size]
+            block_offset = num_computed_tokens % self.block_size
+            slot_id = block_number * self.block_size + block_offset
+            slot_mapping.append([slot_id])
+
+            # Context len
+            context_lens.append(seq_len)
+
+        # Compute padding
+        batch_size = len(input_tokens)
+        padded_batch_size = _get_padded_batch_size(batch_size)
+        num_padding = padded_batch_size - batch_size
+
+        # Add padding
+        input_tokens.extend([[0]] * num_padding)
+        input_positions.extend([[0]] * num_padding)
+        slot_mapping.extend([[_PAD_SLOT_ID]] * num_padding)
+        context_lens.extend([0] * num_padding)
+        req_indices.extend([0] * num_padding)
+
+        # Create tensors
+        input_tokens_tensor = torch.tensor(input_tokens,
+                                           dtype=torch.int32,
+                                           device="cpu")
+        input_positions_tensor = torch.tensor(input_positions,
+                                              dtype=torch.int32,
+                                              device="cpu")
+        slot_mapping_tensor = torch.tensor(slot_mapping,
+                                           dtype=torch.int64,
+                                           device="cpu")
+        context_lens_tensor = torch.tensor(context_lens,
+                                           dtype=torch.int32,
+                                           device="cpu")
+        block_tables_tensor = block_table_cpu_tensor[req_indices]
+
+        # Attn metadata
+        attn_metadata = PallasMetadata(
+            num_prefills=0,
+            num_prefill_tokens=0,
+            num_decode_tokens=padded_batch_size,
+            slot_mapping=slot_mapping_tensor.to(self.device),
+            multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=True,
+            block_tables=block_tables_tensor.to(self.device),
+            context_lens=context_lens_tensor.to(self.device),
+            effective_query_lens=None,
+        )
+
+        return DecodeInputData(
+            req_ids=req_ids,
+            input_tokens=input_tokens_tensor.to(self.device),
+            input_positions=input_positions_tensor.to(self.device),
+            attn_metadata=attn_metadata)
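Decodes still run as one [padded_batch, 1] batch, but the new code builds per-decode Python lists and then pads every list to a static batch size; padded rows use token 0, position 0, _PAD_SLOT_ID, and context length 0, so they contribute nothing to the KV cache or attention. A sketch of just the padding step in isolation (hypothetical helper mirroring the diff; _PAD_SLOT_ID = -1 is an assumption, the real constant is backend-defined):

_PAD_SLOT_ID = -1  # assumption: an out-of-range slot ignored on KV insertion

def pad_decode_batch(input_tokens: list, input_positions: list,
                     slot_mapping: list, context_lens: list,
                     padded_batch_size: int) -> None:
    """Pad per-decode lists in place up to a static batch size."""
    num_padding = padded_batch_size - len(input_tokens)
    input_tokens.extend([[0]] * num_padding)
    input_positions.extend([[0]] * num_padding)
    slot_mapping.extend([[_PAD_SLOT_ID]] * num_padding)
    context_lens.extend([0] * num_padding)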
     @torch.no_grad()
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
     ) -> ModelRunnerOutput:
-        self._update_states(scheduler_output)
+        # Update cached state
+        self.update_states(scheduler_output)

-        # Prepare the decoder inputs.
-        prefill_data, decode_data = self._prepare_inputs(scheduler_output)
+        # Prepare inputs
+        prompt_data = self._prepare_prompt_inputs(scheduler_output)
+        decode_data = self._prepare_decode_inputs(scheduler_output)
+
+        # Init
+        num_reqs = self.input_batch.num_reqs
+        assert num_reqs > 0
+        sampled_token_ids_list = [0] * num_reqs

-        ######################### DECODES #########################
-        # Decodes run as one single batch with [padded_batch, 1]
-        sampled_token_ids_list = []
-        if decode_data.num_decodes > 0:
-            # FORWARD.
+        # Run decodes (a single batch)
+        if len(decode_data.req_ids) > 0:
+            # Forward
             with set_forward_context(decode_data.attn_metadata,
                                      self.vllm_config):
                 assert self.model is not None
-                selected_token_ids = self.model(decode_data.token_ids,
-                                                decode_data.position_ids,
+                selected_token_ids = self.model(decode_data.input_tokens,
+                                                decode_data.input_positions,
                                                 decode_data.attn_metadata,
                                                 self.kv_caches)

-            # NOTE: TPU<>CPU sync happens here.
-            # We need to call .cpu() first to avoid recompilation.
-            token_ids = selected_token_ids.cpu()[:decode_data.num_decodes]
-            sampled_token_ids_list.extend(token_ids.tolist())
+            # Transfer sampled tokens from TPU to CPU
+            selected_token_ids_list = selected_token_ids.cpu().tolist()

-            # UPDATE REQUEST STATE.
-            for i, req_id in enumerate(
-                    self.input_batch.req_ids[:decode_data.num_decodes]):
-                assert req_id is not None
+            # Update cached state
+            for i, req_id in enumerate(decode_data.req_ids):
+                req_index = self.input_batch.req_id_to_index[req_id]
                 req_state = self.requests[req_id]

                 assert scheduler_output.num_scheduled_tokens[req_id] == 1
                 seq_len = (req_state.num_computed_tokens +
                            scheduler_output.num_scheduled_tokens[req_id])
                 assert seq_len == req_state.num_tokens

-                token_id = sampled_token_ids_list[i]
-                self.input_batch.token_ids_cpu[i, seq_len] = token_id
-                self.input_batch.num_tokens[i] += 1
+                token_id = selected_token_ids_list[i]
+
+                self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
+                self.input_batch.num_tokens[req_index] += 1
                 req_state.output_token_ids.append(token_id)
+                sampled_token_ids_list[req_index] = token_id

-        ######################### PREFILLS #########################
-        # Prefills run separately with shape [1, padded_prefill_len],
-        # due to lack of variable length attention kernel so far.
-        for (req_id, prompt_len, token_ids, position_ids,
-             attn_metadata) in prefill_data.zipped():
-            assert req_id is not None
-            # FORWARD.
+        # Run each prompt
+        for (req_id, prompt_len, input_tokens, input_positions,
+             attn_metadata) in prompt_data.zipped():
+            assert req_id is not None
+            req_state = self.requests[req_id]
+            req_index = self.input_batch.req_id_to_index[req_id]
+
+            # Forward
             with set_forward_context(attn_metadata, self.vllm_config):
                 assert self.model is not None
-                selected_token_ids = self.model(token_ids, position_ids,
+                selected_token_ids = self.model(input_tokens, input_positions,
                                                 attn_metadata, self.kv_caches)

-            # NOTE: TPU<>CPU sync happens here.
-            # We need to call .cpu() first to avoid recompilation.
-            token_id = selected_token_ids.cpu()[prompt_len - 1].item()
-            sampled_token_ids_list.append(token_id)
-            req_state = self.requests[req_id]
-
-            assert req_state.num_computed_tokens == 0
             seq_len = (req_state.num_computed_tokens +
                        scheduler_output.num_scheduled_tokens[req_id])
-            assert seq_len == req_state.num_tokens
-            assert prompt_len == seq_len
+            if seq_len >= len(req_state.prompt_token_ids):
+                # Transfer sampled tokens from TPU to CPU
+                token_id = selected_token_ids.cpu()[prompt_len - 1].item()
+                sampled_token_ids_list[req_index] = token_id

-            # UPDATE REQUEST STATE.
-            req_idx = self.input_batch.req_id_to_index[req_id]
-            self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id
-            self.input_batch.num_tokens[req_idx] += 1
-            req_state.output_token_ids.append(token_id)
+                # Update cached state
+                self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
+                self.input_batch.num_tokens[req_index] += 1
+                req_state.output_token_ids.append(token_id)

-        # num_reqs entries should be non-None
+        # Get req_ids
         assert all(
             req_id is not None for req_id in
             self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
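The key behavioral change in execute_model(): a prompt chunk only produces a sampled token once the whole prompt has been consumed (seq_len >= len(prompt_token_ids)); intermediate chunks just populate the KV cache. A condensed sketch of that decision (hypothetical helper, not part of the diff):

def should_sample(num_computed_tokens: int, num_scheduled_tokens: int,
                  num_prompt_tokens: int) -> bool:
    """True once this step finishes the prompt; earlier chunks return False."""
    seq_len = num_computed_tokens + num_scheduled_tokens
    return seq_len >= num_prompt_tokens

# A 5000-token prompt chunked as 2048 + 2048 + 904 samples only on the third
# step; the first two steps only write prompt tokens into the KV cache.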