mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-28 08:07:03 +08:00
[V1] Input Batch Relocation (#10962)
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com> Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
parent
aea2fc38c3
commit
25b79d9fd3
280
vllm/v1/worker/gpu_input_batch.py
Normal file
280
vllm/v1/worker/gpu_input_batch.py
Normal file
@ -0,0 +1,280 @@
|
|||||||
|
# Datastructures defining an input batch
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Dict, List, Optional, Set
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.multimodal import MultiModalKwargs
|
||||||
|
from vllm.sampling_params import SamplingParams, SamplingType
|
||||||
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.multimodal.inputs import PlaceholderRange
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CachedRequestState:
|
||||||
|
|
||||||
|
req_id: str
|
||||||
|
prompt_token_ids: List[int]
|
||||||
|
prompt: Optional[str]
|
||||||
|
mm_inputs: List[MultiModalKwargs]
|
||||||
|
mm_positions: List["PlaceholderRange"]
|
||||||
|
sampling_params: SamplingParams
|
||||||
|
generator: Optional[torch.Generator]
|
||||||
|
|
||||||
|
block_ids: List[int]
|
||||||
|
num_computed_tokens: int
|
||||||
|
output_token_ids: List[int]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_tokens(self) -> int:
|
||||||
|
return len(self.prompt_token_ids) + len(self.output_token_ids)
|
||||||
|
|
||||||
|
|
||||||
|
class InputBatch:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_num_reqs: int,
|
||||||
|
max_model_len: int,
|
||||||
|
max_num_blocks_per_req: int,
|
||||||
|
device: torch.device,
|
||||||
|
pin_memory: bool,
|
||||||
|
):
|
||||||
|
self.max_num_reqs = max_num_reqs
|
||||||
|
self.max_model_len = max_model_len
|
||||||
|
self.max_num_blocks_per_req = max_num_blocks_per_req
|
||||||
|
self.device = device
|
||||||
|
self.pin_memory = pin_memory
|
||||||
|
|
||||||
|
self.req_ids: List[Optional[str]] = [None] * max_num_reqs
|
||||||
|
self.req_id_to_index: Dict[str, int] = {}
|
||||||
|
|
||||||
|
self.token_ids_cpu = np.empty((max_num_reqs, max_model_len),
|
||||||
|
dtype=np.int32)
|
||||||
|
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
|
||||||
|
|
||||||
|
# Attention-related.
|
||||||
|
self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req),
|
||||||
|
device=self.device,
|
||||||
|
dtype=torch.int32)
|
||||||
|
self.block_table_cpu_tensor = torch.zeros(
|
||||||
|
(max_num_reqs, max_num_blocks_per_req),
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.int32,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
)
|
||||||
|
self.block_table_cpu = self.block_table_cpu_tensor.numpy()
|
||||||
|
|
||||||
|
# Sampling-related.
|
||||||
|
self.temperature = torch.empty((max_num_reqs, ),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device)
|
||||||
|
self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device="cpu",
|
||||||
|
pin_memory=pin_memory)
|
||||||
|
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
|
||||||
|
self.greedy_reqs: Set[str] = set()
|
||||||
|
self.random_reqs: Set[str] = set()
|
||||||
|
|
||||||
|
self.top_p = torch.empty((max_num_reqs, ),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device)
|
||||||
|
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device="cpu",
|
||||||
|
pin_memory=pin_memory)
|
||||||
|
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
|
||||||
|
self.top_p_reqs: Set[str] = set()
|
||||||
|
|
||||||
|
self.top_k = torch.empty((max_num_reqs, ),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
|
self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu",
|
||||||
|
pin_memory=pin_memory)
|
||||||
|
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
|
||||||
|
self.top_k_reqs: Set[str] = set()
|
||||||
|
|
||||||
|
# req_index -> generator
|
||||||
|
self.generators: Dict[int, torch.Generator] = {}
|
||||||
|
|
||||||
|
self.num_logprobs: Dict[str, int] = {}
|
||||||
|
self.prompt_logprob_reqs: Set[str] = set()
|
||||||
|
|
||||||
|
def add_request(
|
||||||
|
self,
|
||||||
|
request: "CachedRequestState",
|
||||||
|
req_index: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
if req_index is None:
|
||||||
|
req_index = self.num_reqs
|
||||||
|
assert req_index < self.max_num_reqs
|
||||||
|
|
||||||
|
req_id = request.req_id
|
||||||
|
self.req_ids[req_index] = req_id
|
||||||
|
self.req_id_to_index[req_id] = req_index
|
||||||
|
|
||||||
|
# Copy the prompt token ids and output token ids.
|
||||||
|
num_prompt_tokens = len(request.prompt_token_ids)
|
||||||
|
self.token_ids_cpu[
|
||||||
|
req_index, :num_prompt_tokens] = request.prompt_token_ids
|
||||||
|
start_idx = num_prompt_tokens
|
||||||
|
end_idx = start_idx + len(request.output_token_ids)
|
||||||
|
self.token_ids_cpu[req_index,
|
||||||
|
start_idx:end_idx] = request.output_token_ids
|
||||||
|
|
||||||
|
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
||||||
|
num_blocks = len(request.block_ids)
|
||||||
|
self.block_table_cpu[req_index, :num_blocks] = request.block_ids
|
||||||
|
|
||||||
|
sampling_params = request.sampling_params
|
||||||
|
self.temperature_cpu[req_index] = sampling_params.temperature
|
||||||
|
if sampling_params.sampling_type == SamplingType.GREEDY:
|
||||||
|
self.greedy_reqs.add(req_id)
|
||||||
|
else:
|
||||||
|
self.random_reqs.add(req_id)
|
||||||
|
|
||||||
|
self.top_p_cpu[req_index] = sampling_params.top_p
|
||||||
|
if sampling_params.top_p < 1:
|
||||||
|
self.top_p_reqs.add(req_id)
|
||||||
|
self.top_k_cpu[req_index] = sampling_params.top_k
|
||||||
|
if sampling_params.top_k > 0:
|
||||||
|
self.top_k_reqs.add(req_id)
|
||||||
|
|
||||||
|
self.generators[req_index] = request.generator
|
||||||
|
|
||||||
|
num_logprobs = sampling_params.logprobs
|
||||||
|
if num_logprobs is not None and num_logprobs > 0:
|
||||||
|
self.num_logprobs[req_id] = num_logprobs
|
||||||
|
if sampling_params.prompt_logprobs:
|
||||||
|
self.prompt_logprob_reqs.add(req_id)
|
||||||
|
|
||||||
|
def remove_request(self, req_id: str) -> Optional[int]:
|
||||||
|
req_index = self.req_id_to_index.pop(req_id, None)
|
||||||
|
if req_index is None:
|
||||||
|
return None
|
||||||
|
self.req_ids[req_index] = None
|
||||||
|
|
||||||
|
self.greedy_reqs.discard(req_id)
|
||||||
|
self.random_reqs.discard(req_id)
|
||||||
|
self.top_p_reqs.discard(req_id)
|
||||||
|
self.top_k_reqs.discard(req_id)
|
||||||
|
self.generators.pop(req_index, None)
|
||||||
|
self.num_logprobs.pop(req_id, None)
|
||||||
|
self.prompt_logprob_reqs.discard(req_id)
|
||||||
|
return req_index
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
self.req_ids = [None] * self.max_num_reqs
|
||||||
|
self.req_id_to_index.clear()
|
||||||
|
self.greedy_reqs.clear()
|
||||||
|
self.random_reqs.clear()
|
||||||
|
self.top_p_reqs.clear()
|
||||||
|
self.top_k_reqs.clear()
|
||||||
|
self.generators.clear()
|
||||||
|
self.num_logprobs.clear()
|
||||||
|
self.prompt_logprob_reqs.clear()
|
||||||
|
|
||||||
|
def condense(self, empty_req_indices: List[int]) -> None:
|
||||||
|
if self.num_reqs == 0:
|
||||||
|
# The batched states are empty.
|
||||||
|
return
|
||||||
|
|
||||||
|
# NOTE(woosuk): This function assumes that the empty_req_indices
|
||||||
|
# is sorted in descending order.
|
||||||
|
last_req_index = self.num_reqs + len(empty_req_indices) - 1
|
||||||
|
while empty_req_indices:
|
||||||
|
# Find the largest non-empty index.
|
||||||
|
while last_req_index in empty_req_indices:
|
||||||
|
last_req_index -= 1
|
||||||
|
|
||||||
|
# Find the smallest empty index.
|
||||||
|
empty_index = empty_req_indices.pop()
|
||||||
|
if empty_index >= last_req_index:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Swap the states.
|
||||||
|
req_id = self.req_ids[last_req_index]
|
||||||
|
self.req_ids[empty_index] = req_id
|
||||||
|
self.req_ids[last_req_index] = None
|
||||||
|
self.req_id_to_index[req_id] = empty_index
|
||||||
|
|
||||||
|
# TODO(woosuk): Optimize the copy of token_ids_cpu and
|
||||||
|
# block_table_cpu.
|
||||||
|
self.token_ids_cpu[empty_index] = self.token_ids_cpu[
|
||||||
|
last_req_index]
|
||||||
|
self.num_computed_tokens_cpu[
|
||||||
|
empty_index] = self.num_computed_tokens_cpu[last_req_index]
|
||||||
|
self.block_table_cpu[empty_index] = self.block_table_cpu[
|
||||||
|
last_req_index]
|
||||||
|
self.temperature_cpu[empty_index] = self.temperature_cpu[
|
||||||
|
last_req_index]
|
||||||
|
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
|
||||||
|
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
|
||||||
|
generator = self.generators.pop(last_req_index, None)
|
||||||
|
if generator is not None:
|
||||||
|
self.generators[empty_index] = generator
|
||||||
|
|
||||||
|
# Decrement last_req_index since it is now empty.
|
||||||
|
last_req_index -= 1
|
||||||
|
|
||||||
|
def make_sampling_metadata(
|
||||||
|
self,
|
||||||
|
skip_copy: bool = False,
|
||||||
|
) -> SamplingMetadata:
|
||||||
|
if not skip_copy:
|
||||||
|
self.temperature[:self.num_reqs].copy_(
|
||||||
|
self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True)
|
||||||
|
self.top_p[:self.num_reqs].copy_(
|
||||||
|
self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
|
||||||
|
self.top_k[:self.num_reqs].copy_(
|
||||||
|
self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
|
||||||
|
return SamplingMetadata(
|
||||||
|
temperature=self.temperature[:self.num_reqs],
|
||||||
|
all_greedy=self.all_greedy,
|
||||||
|
all_random=self.all_random,
|
||||||
|
top_p=self.top_p[:self.num_reqs],
|
||||||
|
top_k=self.top_k[:self.num_reqs],
|
||||||
|
no_top_p=self.no_top_p,
|
||||||
|
no_top_k=self.no_top_k,
|
||||||
|
generators=self.generators,
|
||||||
|
max_num_logprobs=self.max_num_logprobs,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_reqs(self) -> int:
|
||||||
|
return len(self.req_id_to_index)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_greedy(self) -> bool:
|
||||||
|
return len(self.random_reqs) == 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_random(self) -> bool:
|
||||||
|
return len(self.greedy_reqs) == 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def no_top_p(self) -> bool:
|
||||||
|
return len(self.top_p_reqs) == 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def no_top_k(self) -> bool:
|
||||||
|
return len(self.top_k_reqs) == 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_num_logprobs(self) -> int:
|
||||||
|
return max(self.num_logprobs.values()) if self.num_logprobs else 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def no_logprob(self) -> bool:
|
||||||
|
return len(self.num_logprobs) == 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def no_prompt_logprob(self) -> bool:
|
||||||
|
return len(self.prompt_logprob_reqs) == 0
|
||||||
@ -1,7 +1,6 @@
|
|||||||
import gc
|
import gc
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -15,16 +14,16 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.multimodal import MultiModalKwargs
|
from vllm.multimodal import MultiModalKwargs
|
||||||
from vllm.sampling_params import SamplingParams, SamplingType
|
from vllm.sampling_params import SamplingType
|
||||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
|
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
|
||||||
is_pin_memory_available)
|
is_pin_memory_available)
|
||||||
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
|
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
|
||||||
FlashAttentionMetadata)
|
FlashAttentionMetadata)
|
||||||
from vllm.v1.outputs import ModelRunnerOutput
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.multimodal.inputs import PlaceholderRange
|
|
||||||
from vllm.v1.core.scheduler import SchedulerOutput
|
from vllm.v1.core.scheduler import SchedulerOutput
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -609,269 +608,3 @@ class GPUModelRunner:
|
|||||||
if batch_size <= size:
|
if batch_size <= size:
|
||||||
return size
|
return size
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CachedRequestState:
|
|
||||||
|
|
||||||
req_id: str
|
|
||||||
prompt_token_ids: List[int]
|
|
||||||
prompt: Optional[str]
|
|
||||||
mm_inputs: List[MultiModalKwargs]
|
|
||||||
mm_positions: List["PlaceholderRange"]
|
|
||||||
sampling_params: SamplingParams
|
|
||||||
generator: Optional[torch.Generator]
|
|
||||||
|
|
||||||
block_ids: List[int]
|
|
||||||
num_computed_tokens: int
|
|
||||||
output_token_ids: List[int]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_tokens(self) -> int:
|
|
||||||
return len(self.prompt_token_ids) + len(self.output_token_ids)
|
|
||||||
|
|
||||||
|
|
||||||
class InputBatch:
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
max_num_reqs: int,
|
|
||||||
max_model_len: int,
|
|
||||||
max_num_blocks_per_req: int,
|
|
||||||
device: torch.device,
|
|
||||||
pin_memory: bool,
|
|
||||||
):
|
|
||||||
self.max_num_reqs = max_num_reqs
|
|
||||||
self.max_model_len = max_model_len
|
|
||||||
self.max_num_blocks_per_req = max_num_blocks_per_req
|
|
||||||
self.device = device
|
|
||||||
self.pin_memory = pin_memory
|
|
||||||
|
|
||||||
self.req_ids: List[Optional[str]] = [None] * max_num_reqs
|
|
||||||
self.req_id_to_index: Dict[str, int] = {}
|
|
||||||
|
|
||||||
self.token_ids_cpu = np.empty((max_num_reqs, max_model_len),
|
|
||||||
dtype=np.int32)
|
|
||||||
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
|
|
||||||
|
|
||||||
# Attention-related.
|
|
||||||
self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req),
|
|
||||||
device=self.device,
|
|
||||||
dtype=torch.int32)
|
|
||||||
self.block_table_cpu_tensor = torch.zeros(
|
|
||||||
(max_num_reqs, max_num_blocks_per_req),
|
|
||||||
device="cpu",
|
|
||||||
dtype=torch.int32,
|
|
||||||
pin_memory=pin_memory,
|
|
||||||
)
|
|
||||||
self.block_table_cpu = self.block_table_cpu_tensor.numpy()
|
|
||||||
|
|
||||||
# Sampling-related.
|
|
||||||
self.temperature = torch.empty((max_num_reqs, ),
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=device)
|
|
||||||
self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
|
|
||||||
dtype=torch.float32,
|
|
||||||
device="cpu",
|
|
||||||
pin_memory=pin_memory)
|
|
||||||
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
|
|
||||||
self.greedy_reqs: Set[str] = set()
|
|
||||||
self.random_reqs: Set[str] = set()
|
|
||||||
|
|
||||||
self.top_p = torch.empty((max_num_reqs, ),
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=device)
|
|
||||||
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
|
||||||
dtype=torch.float32,
|
|
||||||
device="cpu",
|
|
||||||
pin_memory=pin_memory)
|
|
||||||
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
|
|
||||||
self.top_p_reqs: Set[str] = set()
|
|
||||||
|
|
||||||
self.top_k = torch.empty((max_num_reqs, ),
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=device)
|
|
||||||
self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
|
|
||||||
dtype=torch.int32,
|
|
||||||
device="cpu",
|
|
||||||
pin_memory=pin_memory)
|
|
||||||
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
|
|
||||||
self.top_k_reqs: Set[str] = set()
|
|
||||||
|
|
||||||
# req_index -> generator
|
|
||||||
self.generators: Dict[int, torch.Generator] = {}
|
|
||||||
|
|
||||||
self.num_logprobs: Dict[str, int] = {}
|
|
||||||
self.prompt_logprob_reqs: Set[str] = set()
|
|
||||||
|
|
||||||
def add_request(
|
|
||||||
self,
|
|
||||||
request: "CachedRequestState",
|
|
||||||
req_index: Optional[int] = None,
|
|
||||||
) -> None:
|
|
||||||
if req_index is None:
|
|
||||||
req_index = self.num_reqs
|
|
||||||
assert req_index < self.max_num_reqs
|
|
||||||
|
|
||||||
req_id = request.req_id
|
|
||||||
self.req_ids[req_index] = req_id
|
|
||||||
self.req_id_to_index[req_id] = req_index
|
|
||||||
|
|
||||||
# Copy the prompt token ids and output token ids.
|
|
||||||
num_prompt_tokens = len(request.prompt_token_ids)
|
|
||||||
self.token_ids_cpu[
|
|
||||||
req_index, :num_prompt_tokens] = request.prompt_token_ids
|
|
||||||
start_idx = num_prompt_tokens
|
|
||||||
end_idx = start_idx + len(request.output_token_ids)
|
|
||||||
self.token_ids_cpu[req_index,
|
|
||||||
start_idx:end_idx] = request.output_token_ids
|
|
||||||
|
|
||||||
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
|
||||||
num_blocks = len(request.block_ids)
|
|
||||||
self.block_table_cpu[req_index, :num_blocks] = request.block_ids
|
|
||||||
|
|
||||||
sampling_params = request.sampling_params
|
|
||||||
self.temperature_cpu[req_index] = sampling_params.temperature
|
|
||||||
if sampling_params.sampling_type == SamplingType.GREEDY:
|
|
||||||
self.greedy_reqs.add(req_id)
|
|
||||||
else:
|
|
||||||
self.random_reqs.add(req_id)
|
|
||||||
|
|
||||||
self.top_p_cpu[req_index] = sampling_params.top_p
|
|
||||||
if sampling_params.top_p < 1:
|
|
||||||
self.top_p_reqs.add(req_id)
|
|
||||||
self.top_k_cpu[req_index] = sampling_params.top_k
|
|
||||||
if sampling_params.top_k > 0:
|
|
||||||
self.top_k_reqs.add(req_id)
|
|
||||||
|
|
||||||
self.generators[req_index] = request.generator
|
|
||||||
|
|
||||||
num_logprobs = sampling_params.logprobs
|
|
||||||
if num_logprobs is not None and num_logprobs > 0:
|
|
||||||
self.num_logprobs[req_id] = num_logprobs
|
|
||||||
if sampling_params.prompt_logprobs:
|
|
||||||
self.prompt_logprob_reqs.add(req_id)
|
|
||||||
|
|
||||||
def remove_request(self, req_id: str) -> Optional[int]:
|
|
||||||
req_index = self.req_id_to_index.pop(req_id, None)
|
|
||||||
if req_index is None:
|
|
||||||
return None
|
|
||||||
self.req_ids[req_index] = None
|
|
||||||
|
|
||||||
self.greedy_reqs.discard(req_id)
|
|
||||||
self.random_reqs.discard(req_id)
|
|
||||||
self.top_p_reqs.discard(req_id)
|
|
||||||
self.top_k_reqs.discard(req_id)
|
|
||||||
self.generators.pop(req_index, None)
|
|
||||||
self.num_logprobs.pop(req_id, None)
|
|
||||||
self.prompt_logprob_reqs.discard(req_id)
|
|
||||||
return req_index
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
self.req_ids = [None] * self.max_num_reqs
|
|
||||||
self.req_id_to_index.clear()
|
|
||||||
self.greedy_reqs.clear()
|
|
||||||
self.random_reqs.clear()
|
|
||||||
self.top_p_reqs.clear()
|
|
||||||
self.top_k_reqs.clear()
|
|
||||||
self.generators.clear()
|
|
||||||
self.num_logprobs.clear()
|
|
||||||
self.prompt_logprob_reqs.clear()
|
|
||||||
|
|
||||||
def condense(self, empty_req_indices: List[int]) -> None:
|
|
||||||
if self.num_reqs == 0:
|
|
||||||
# The batched states are empty.
|
|
||||||
return
|
|
||||||
|
|
||||||
# NOTE(woosuk): This function assumes that the empty_req_indices
|
|
||||||
# is sorted in descending order.
|
|
||||||
last_req_index = self.num_reqs + len(empty_req_indices) - 1
|
|
||||||
while empty_req_indices:
|
|
||||||
# Find the largest non-empty index.
|
|
||||||
while last_req_index in empty_req_indices:
|
|
||||||
last_req_index -= 1
|
|
||||||
|
|
||||||
# Find the smallest empty index.
|
|
||||||
empty_index = empty_req_indices.pop()
|
|
||||||
if empty_index >= last_req_index:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Swap the states.
|
|
||||||
req_id = self.req_ids[last_req_index]
|
|
||||||
self.req_ids[empty_index] = req_id
|
|
||||||
self.req_ids[last_req_index] = None
|
|
||||||
self.req_id_to_index[req_id] = empty_index
|
|
||||||
|
|
||||||
# TODO(woosuk): Optimize the copy of token_ids_cpu and
|
|
||||||
# block_table_cpu.
|
|
||||||
self.token_ids_cpu[empty_index] = self.token_ids_cpu[
|
|
||||||
last_req_index]
|
|
||||||
self.num_computed_tokens_cpu[
|
|
||||||
empty_index] = self.num_computed_tokens_cpu[last_req_index]
|
|
||||||
self.block_table_cpu[empty_index] = self.block_table_cpu[
|
|
||||||
last_req_index]
|
|
||||||
self.temperature_cpu[empty_index] = self.temperature_cpu[
|
|
||||||
last_req_index]
|
|
||||||
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
|
|
||||||
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
|
|
||||||
generator = self.generators.pop(last_req_index, None)
|
|
||||||
if generator is not None:
|
|
||||||
self.generators[empty_index] = generator
|
|
||||||
|
|
||||||
# Decrement last_req_index since it is now empty.
|
|
||||||
last_req_index -= 1
|
|
||||||
|
|
||||||
def make_sampling_metadata(
|
|
||||||
self,
|
|
||||||
skip_copy: bool = False,
|
|
||||||
) -> SamplingMetadata:
|
|
||||||
if not skip_copy:
|
|
||||||
self.temperature[:self.num_reqs].copy_(
|
|
||||||
self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True)
|
|
||||||
self.top_p[:self.num_reqs].copy_(
|
|
||||||
self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
|
|
||||||
self.top_k[:self.num_reqs].copy_(
|
|
||||||
self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
|
|
||||||
return SamplingMetadata(
|
|
||||||
temperature=self.temperature[:self.num_reqs],
|
|
||||||
all_greedy=self.all_greedy,
|
|
||||||
all_random=self.all_random,
|
|
||||||
top_p=self.top_p[:self.num_reqs],
|
|
||||||
top_k=self.top_k[:self.num_reqs],
|
|
||||||
no_top_p=self.no_top_p,
|
|
||||||
no_top_k=self.no_top_k,
|
|
||||||
generators=self.generators,
|
|
||||||
max_num_logprobs=self.max_num_logprobs,
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_reqs(self) -> int:
|
|
||||||
return len(self.req_id_to_index)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def all_greedy(self) -> bool:
|
|
||||||
return len(self.random_reqs) == 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def all_random(self) -> bool:
|
|
||||||
return len(self.greedy_reqs) == 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def no_top_p(self) -> bool:
|
|
||||||
return len(self.top_p_reqs) == 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def no_top_k(self) -> bool:
|
|
||||||
return len(self.top_k_reqs) == 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def max_num_logprobs(self) -> int:
|
|
||||||
return max(self.num_logprobs.values()) if self.num_logprobs else 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def no_logprob(self) -> bool:
|
|
||||||
return len(self.num_logprobs) == 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def no_prompt_logprob(self) -> bool:
|
|
||||||
return len(self.prompt_logprob_reqs) == 0
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user