[Core] Modulize prepare input and attention metadata builder (#6596)
parent bdf5fd1386 · commit e0c15758b8
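At a glance, the commit replaces one long per-sequence bookkeeping path with small compute functions that fill a per-sequence-group intermediate record, which attention metadata builders then read from the input builder inside build(). Below is a minimal, self-contained sketch of that shape; every name in it is a toy placeholder rather than vLLM's actual API, and the real classes (ModelInputForGPUBuilder.InterDataForSeqGroup, per_seq_compute_fns, the reworked AttentionMetadataBuilder.build signature) appear in the diff that follows.

```python
# Minimal sketch of the pattern this commit introduces. All names here are
# toy placeholders (InterData, ToyInputBuilder, ...), not vLLM's real classes;
# the real ones appear in the diff below.
from dataclasses import dataclass, field
from typing import List


@dataclass
class InterData:
    """Per-sequence-group intermediate data kept on the CPU."""
    seq_ids: List[int]
    is_prompt: bool
    seq_lens: List[int] = field(default_factory=list)
    query_lens: List[int] = field(default_factory=list)

    def __post_init__(self):
        # Size every per-sequence list to the number of sequences.
        self.seq_lens = [0] * len(self.seq_ids)
        self.query_lens = [0] * len(self.seq_ids)


class ToyMetadataBuilder:
    """Only needs the input builder; build() no longer takes the runner."""

    def __init__(self, input_builder: "ToyInputBuilder") -> None:
        self.input_builder = input_builder

    def build(self, seq_lens: List[int], query_lens: List[int]) -> dict:
        # Intermediate data was already collected by the input builder.
        return {"seq_lens": seq_lens, "query_lens": query_lens}


class ToyInputBuilder:
    def __init__(self) -> None:
        self.inter_data_list: List[InterData] = []
        # Per-sequence compute functions run in a fixed order.
        self.per_seq_compute_fns = [self._compute_lens]
        self.attn_metadata_builder = ToyMetadataBuilder(self)

    def _compute_lens(self, inter_data: InterData, seq_idx: int,
                      prompt_len: int) -> None:
        inter_data.seq_lens[seq_idx] = prompt_len
        inter_data.query_lens[seq_idx] = (prompt_len
                                          if inter_data.is_prompt else 1)

    def add_seq_group(self, seq_ids: List[int], is_prompt: bool,
                      prompt_len: int) -> None:
        inter_data = InterData(seq_ids=seq_ids, is_prompt=is_prompt)
        self.inter_data_list.append(inter_data)
        for seq_idx in range(len(seq_ids)):
            for fn in self.per_seq_compute_fns:
                fn(inter_data, seq_idx, prompt_len)

    def build(self) -> dict:
        seq_lens = [n for d in self.inter_data_list for n in d.seq_lens]
        query_lens = [n for d in self.inter_data_list for n in d.query_lens]
        return self.attn_metadata_builder.build(seq_lens, query_lens)


if __name__ == "__main__":
    builder = ToyInputBuilder()
    builder.add_seq_group(seq_ids=[0], is_prompt=True, prompt_len=8)
    builder.add_seq_group(seq_ids=[1], is_prompt=False, prompt_len=9)
    print(builder.build())  # {'seq_lens': [8, 9], 'query_lens': [8, 1]}
```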
@@ -7,7 +7,6 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
import torch

if TYPE_CHECKING:
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase


@@ -128,25 +127,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders."""

@abstractmethod
def __init__(self, input_builder) -> None:
def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
raise NotImplementedError

@abstractmethod
def add_seq_group(self, seq_group_metadata: "SequenceGroupMetadata",
token_lens: List[int], seq_lens: List[int],
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int],
prefix_cache_hit: bool, chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata and update
corresponding fields (in Python objects).
"""
raise NotImplementedError

@abstractmethod
def build(self, runner: "ModelRunnerInputBuilderBase", seq_lens: List[int],
query_lens: List[int], cuda_graph_pad_size: int,
batch_size: int) -> T:
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int) -> T:
"""Build attention metadata with on-device tensors."""
raise NotImplementedError

@@ -13,12 +13,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad

if TYPE_CHECKING:
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder)
from vllm.worker.model_runner import ModelInputForGPUBuilder


class FlashAttentionBackend(AttentionBackend):
@@ -212,30 +210,30 @@ class FlashAttentionMetadataBuilder(
self.num_prefill_tokens = 0
self.num_decode_tokens = 0

self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)

def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
token_lens: List[int], seq_lens: List[int],
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int],
prefix_cache_hit: bool, chunked_prefill_enabled: bool):
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt = seq_group_metadata.is_prompt
block_tables = seq_group_metadata.block_tables
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables

for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
curr_seq_lens, query_lens, context_lens,
curr_sliding_window_blocks):
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)

if is_prompt:
@@ -254,7 +252,7 @@ class FlashAttentionMetadataBuilder(
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
if inter_data.prefix_cache_hit:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id]
@@ -270,16 +268,19 @@ class FlashAttentionMetadataBuilder(
self.use_v2_block_manager)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size,
seq_group_metadata.block_tables)
self.block_size, inter_data.block_tables)

def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors."""
device = runner.device
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)

device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1

logits_soft_cap = getattr(runner.model_config.hf_config,
logits_soft_cap = getattr(self.runner.model_config.hf_config,
"attn_logit_softcapping", None)
if logits_soft_cap is not None:
raise ValueError(
@@ -300,7 +301,7 @@ class FlashAttentionMetadataBuilder(

# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = runner.graph_block_tables[:batch_size]
input_block_tables = self.runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table

@@ -21,12 +21,10 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad

if TYPE_CHECKING:
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder)
from vllm.worker.model_runner import ModelInputForGPUBuilder


class FlashInferBackend(AttentionBackend):
@@ -216,6 +214,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.num_prefill_tokens = 0
self.num_decode_tokens = 0

self.input_builder = input_builder
self.runner = input_builder.runner

self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
@@ -238,26 +239,24 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# paged_kv_last_page_len is the length of the last page of each request
self.paged_kv_last_page_len: List[int] = []

def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
token_lens: List[int], seq_lens: List[int],
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int],
prefix_cache_hit: bool, chunked_prefill_enabled: bool):
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt = seq_group_metadata.is_prompt
block_tables = seq_group_metadata.block_tables
computed_block_nums = seq_group_metadata.computed_block_nums
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
computed_block_nums = inter_data.computed_block_nums

for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
curr_seq_lens, query_lens, context_lens,
curr_sliding_window_blocks):
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1
@@ -275,7 +274,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
if inter_data.prefix_cache_hit:
block_table = computed_block_nums
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
@@ -290,8 +289,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.use_v2_block_manager)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size,
seq_group_metadata.block_tables)
self.block_size, inter_data.block_tables)

# It is not necessary to add paged_kv_indices, paged_kv_indptr,
# and paged_kv_last_page_len for profile run because we will
@@ -317,9 +315,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
last_page_len = self.block_size
self.paged_kv_last_page_len.append(last_page_len)

def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
device = runner.device
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)

device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1

max_query_len = max(query_lens)
@@ -333,7 +335,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = runner.graph_block_tables[:batch_size]
input_block_tables = self.runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
@@ -377,7 +379,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype=torch.long,
device=device)

logits_soft_cap = getattr(runner.model_config.hf_config,
logits_soft_cap = getattr(self.runner.model_config.hf_config,
"attn_logit_softcapping", None)

if len(self.paged_kv_indptr) > 0:
@@ -394,8 +396,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr_tensor = None
paged_kv_last_page_len_tensor = None

kv_cache_dtype = get_kv_cache_torch_dtype(runner.kv_cache_dtype,
runner.model_config.dtype)
kv_cache_dtype = get_kv_cache_torch_dtype(
self.runner.kv_cache_dtype, self.runner.model_config.dtype)
return FlashInferMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
@@ -406,11 +408,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr=paged_kv_indptr_tensor,
paged_kv_indices=paged_kv_indices_tensor,
paged_kv_last_page_len=paged_kv_last_page_len_tensor,
num_qo_heads=runner.model_config.get_num_attention_heads(
runner.parallel_config),
num_kv_heads=runner.model_config.get_num_kv_heads(
runner.parallel_config),
head_dim=runner.model_config.get_head_size(),
num_qo_heads=self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config),
num_kv_heads=self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config),
head_dim=self.runner.model_config.get_head_size(),
page_size=self.block_size,
seq_start_loc=seq_start_loc,
query_start_loc=query_start_loc,

@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
import torch

from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad

# Error string(s) for encoder/decoder
@@ -15,8 +14,7 @@ STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
PAD_SLOT_ID = -1

if TYPE_CHECKING:
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder)
from vllm.worker.model_runner import ModelInputForGPUBuilder


def is_block_tables_empty(block_tables: Union[None, Dict]):
@@ -95,26 +93,27 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self.num_prefill_tokens = 0
self.num_decode_tokens = 0

self.input_builder = input_builder
self.runner = input_builder.runner

self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)

def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
token_lens: List[int], seq_lens: List[int],
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int], prefix_cache_hit,
chunked_prefill_enabled):
is_prompt = seq_group_metadata.is_prompt
block_tables = seq_group_metadata.block_tables
computed_block_nums = seq_group_metadata.computed_block_nums
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
computed_block_nums = inter_data.computed_block_nums

for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
curr_seq_lens, query_lens, context_lens,
curr_sliding_window_blocks):
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1
@@ -132,7 +131,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
if inter_data.prefix_cache_hit:
block_table = computed_block_nums
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
@@ -146,16 +145,18 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self.use_v2_block_manager)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size,
seq_group_metadata.block_tables)
self.block_size, inter_data.block_tables)

def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int],
query_lens: List[int], cuda_graph_pad_size: int,
batch_size: int):
device = runner.device
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)

device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1

logits_soft_cap = getattr(runner.model_config.hf_config,
logits_soft_cap = getattr(self.runner.model_config.hf_config,
"attn_logit_softcapping", None)
if logits_soft_cap is not None:
raise ValueError(
@@ -176,7 +177,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):

# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = runner.graph_block_tables[:batch_size]
input_block_tables = self.runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table

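The CommonMetadataBuilder changes above mirror the flash-attn and FlashInfer builders. As a side note, is_block_tables_empty stays in vllm/attention/backends/utils.py while the duplicate _is_block_tables_empty in model_runner.py is deleted at the end of this diff; a small hedged check of the behavior, assuming the kept helper matches the removed body:

```python
# Hedged check: assumes a working vLLM install at this commit and that
# is_block_tables_empty behaves like the removed _is_block_tables_empty
# shown at the bottom of this diff (None, or a dict of all-None values,
# counts as "empty").
from vllm.attention.backends.utils import is_block_tables_empty

assert is_block_tables_empty(None)
assert is_block_tables_empty({0: None, 1: None})
assert not is_block_tables_empty({0: [1, 2, 3]})
```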
@@ -719,6 +719,11 @@ def merge_dicts(dict1: Dict[K, List[T]],
return dict(merged_dict)


def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
"""Flatten a list of lists to a single list."""
return [item for sublist in lists for item in sublist]


def init_cached_hf_modules() -> None:
"""
Lazy initialization of the Hugging Face modules.

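flatten_2d_lists is the helper this hunk adds to vllm/utils.py and that model_runner.py starts importing; a minimal usage check, assuming vLLM is importable:

```python
# Quick check of the flatten_2d_lists helper added above (requires vLLM on
# the path); it is exactly the list comprehension shown in the diff.
from vllm.utils import flatten_2d_lists

assert flatten_2d_lists([[1, 2], [], [3]]) == [1, 2, 3]
# model_runner.py uses it to merge per-sequence token lists into one flat
# batch, e.g. [[t00, t01], [t10]] -> [t00, t01, t10].
```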
@@ -3,7 +3,7 @@ import gc
import time
import warnings
import weakref
from collections import defaultdict
from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
Tuple, Type, TypeVar, Union)

@@ -49,7 +49,8 @@ from vllm.prompt_adapter.worker_manager import (
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
from vllm.utils import (CudaMemoryProfiler, flatten_2d_lists,
get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
@@ -76,7 +77,7 @@ _NUM_WARMUP_ITERS = 2
TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")


@dataclasses.dataclass(frozen=True)
@dataclass(frozen=True)
class ModelInputForGPU(ModelRunnerInputBase):
"""
This base class contains metadata needed for the base model forward pass
@@ -126,7 +127,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
return cls(**tensor_dict)


@dataclasses.dataclass(frozen=True)
@dataclass(frozen=True)
class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
"""
Used by the ModelRunner.
@@ -168,12 +169,84 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):


class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
"""TBA"""
"""Build ModelInputForGPU from SequenceGroupMetadata."""

@dataclass
class InterDataForSeqGroup:
"""Intermediate data for the current sequence group."""
# From sequence group metadata.
request_id: str
seq_ids: List[int]
is_prompt: bool
block_tables: Optional[Dict[int, List[int]]]
computed_block_nums: List[int]
n_seqs: int = 0

# Input tokens and positions.
input_tokens: List[List[int]] = field(default_factory=list)
input_positions: List[List[int]] = field(default_factory=list)

# The sequence length (may be capped to the sliding window).
seq_lens: List[int] = field(default_factory=list)
# The original sequence length (before applying sliding window).
# This is used to compute slot mapping.
orig_seq_lens: List[int] = field(default_factory=list)
# The query length.
query_lens: List[int] = field(default_factory=list)
# The number of tokens that are already computed.
context_lens: List[int] = field(default_factory=list)
# The current sliding window block.
curr_sliding_window_blocks: List[int] = field(default_factory=list)

# LoRA inputs.
lora_index_mapping: List[List[int]] = field(default_factory=list)
lora_prompt_mapping: List[List[int]] = field(default_factory=list)
lora_requests: Set[LoRARequest] = field(default_factory=set)

# Prompt adapter inputs.
prompt_adapter_index_mapping: List[int] = field(default_factory=list)
prompt_adapter_prompt_mapping: List[int] = field(default_factory=list)
prompt_adapter_request: Optional[PromptAdapterRequest] = None

# Multi-modal inputs.
multi_modal_inputs: Optional[MultiModalInputs] = None

# Whether the prefix cache is hit (prefill only).
prefix_cache_hit: bool = False

def __post_init__(self):
self.n_seqs = len(self.seq_ids)

self.input_tokens = [[] for _ in range(self.n_seqs)]
self.input_positions = [[] for _ in range(self.n_seqs)]
self.seq_lens = [0] * self.n_seqs
self.orig_seq_lens = [0] * self.n_seqs
self.query_lens = [0] * self.n_seqs
self.context_lens = [0] * self.n_seqs
self.curr_sliding_window_blocks = [0] * self.n_seqs

self.lora_index_mapping = [[] for _ in range(self.n_seqs)]
self.lora_prompt_mapping = [[] for _ in range(self.n_seqs)]

def __init__(self,
runner: "GPUModelRunnerBase",
finished_requests_ids: Optional[List[str]] = None):
super().__init__()
# Compute functions for each sequence in a sequence group.
# WARNING: The order of the functions matters!
self.per_seq_compute_fns = [
self._compute_lens,
self._compute_for_prefix_cache_hit,
self._compute_for_sliding_window,
self._compute_lora_input,
]
# Compute functions for each sequence group.
# WARNING: The order of the functions matters!
self.per_seq_group_compute_fns = [
self._compute_prompt_adapter_input,
self._compute_multi_modal_input,
]

self.runner = runner
self.model_input_cls = self.runner._model_input_cls
self.attn_backend = self.runner.attn_backend
@@ -187,30 +260,14 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.finished_requests_ids = finished_requests_ids
self.decode_only = True

# Common inputs.
self.input_tokens: List[int] = []
self.input_positions: List[int] = []
self.seq_lens: List[int] = []
self.query_lens: List[int] = []
self.max_decode_seq_len: int = 0
self.request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list)

# LoRA inputs.
self.lora_index_mapping: List[int] = []
self.lora_prompt_mapping: List[int] = []
self.lora_requests: Set[LoRARequest] = set()

# Prompt adapter inputs.
self.prompt_adapter_index_mapping: List[int] = []
self.prompt_adapter_prompt_mapping: List[int] = []
self.prompt_adapter_requests: Set[PromptAdapterRequest] = set()

# Multi-modal inputs.
self.multi_modal_inputs_list: List[MultiModalInputs] = []
# Intermediate data (data in CPU before going to GPU) for
# the current sequence group.
self.inter_data_list: List[
ModelInputForGPUBuilder.InterDataForSeqGroup] = []

# Attention metadata inputs.
self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
self)
weakref.proxy(self))

# Engine/Model configurations.
self.chunked_prefill_enabled = (
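The new InterDataForSeqGroup dataclass above is the per-sequence-group record that the compute functions below fill in. A hedged construction example, assuming a vLLM checkout at this commit:

```python
# Hedged construction example (assumes a vLLM checkout at this commit; the
# dataclass is nested inside ModelInputForGPUBuilder in
# vllm/worker/model_runner.py).
from vllm.worker.model_runner import ModelInputForGPUBuilder

inter_data = ModelInputForGPUBuilder.InterDataForSeqGroup(
    request_id="req-0",
    seq_ids=[0, 1],
    is_prompt=False,
    block_tables=None,
    computed_block_nums=[],
)
# __post_init__ sizes every per-sequence list to n_seqs entries.
assert inter_data.n_seqs == 2
assert inter_data.seq_lens == [0, 0]
assert inter_data.input_tokens == [[], []]
```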
@@ -222,175 +279,222 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.block_aligned_sliding_window = \
self.sliding_window_blocks * self.block_size

def _compute_len_for_sliding_window(self, seq_len: int):
curr_sliding_window_blocks = 0
sliding_seq_len = seq_len
def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
"""Compute context length, sequence length and tokens
for the given sequence data.
"""
seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]]
token_chunk_size = seq_group_metadata.token_chunk_size

# TODO(sang): This is a hack to make sliding window work with
# paged attn. We can remove it if we make paged attn kernel
# to properly handle slinding window attn.
if self.sliding_window is not None:
curr_sliding_window_blocks = self.sliding_window_blocks
# Compute context length (the number of tokens that are
# already computed) and sequence length (total number of tokens).
seq_len = seq_data.get_len()
if inter_data.is_prompt:
context_len = seq_data.get_num_computed_tokens()
else:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
context_len = seq_len - 1
seq_len = min(seq_len, context_len + token_chunk_size)

# Compute tokens.
if inter_data.is_prompt:
tokens = seq_data.get_token_ids()[context_len:seq_len]
else:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens = [seq_data.get_last_token_id()]

inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.context_lens[seq_idx] = context_len
inter_data.input_tokens[seq_idx] = tokens
inter_data.input_positions[seq_idx] = list(range(context_len, seq_len))
inter_data.query_lens[
seq_idx] = seq_len - context_len if inter_data.is_prompt else 1

def _compute_for_prefix_cache_hit(
self, inter_data: InterDataForSeqGroup, seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
"""Check if hit prefix cache (i.e., some blocks are already computed).
If hit, update input tokens and positions to only compute the
remaining blocks.
"""
computed_block_nums = inter_data.computed_block_nums

# Note that prefix caching does not support sliding window.
prefix_cache_hit = (computed_block_nums is not None
and len(computed_block_nums) > 0
and self.sliding_window is None
and inter_data.is_prompt)
inter_data.prefix_cache_hit = prefix_cache_hit
if self.chunked_prefill_enabled and prefix_cache_hit:
raise RuntimeError(
"chunked prefill cannot be used with prefix caching now.")

# If prefix cache is hit, advance context length to bypass
# hit blocks. Accordingly, input tokens, position and query length
# have to be updated.
if prefix_cache_hit:
assert computed_block_nums is not None
context_len = len(computed_block_nums) * self.block_size
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][context_len:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][context_len:]
inter_data.context_lens[seq_idx] = context_len
inter_data.query_lens[
seq_idx] = inter_data.seq_lens[seq_idx] - context_len

def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
"""Update seq_len and curr_sliding_window_block for the given
sequence data (only required by decoding) if sliding window is enabled.
"""
curr_sliding_window_block = 0
sliding_seq_len = inter_data.seq_lens[seq_idx]
if not inter_data.is_prompt and self.sliding_window is not None:
# TODO(sang): This is a hack to make sliding window work with
# paged attn. We can remove it if we make paged attn kernel
# to properly handle slinding window attn.
curr_sliding_window_block = self.sliding_window_blocks
if self.scheduler_config.use_v2_block_manager:
# number of elements in last block
suff_len = seq_len % self.block_size
suff_len = inter_data.seq_lens[seq_idx] % self.block_size
sliding_seq_len = min(
seq_len, self.block_aligned_sliding_window + suff_len)
inter_data.seq_lens[seq_idx],
self.block_aligned_sliding_window + suff_len)
if suff_len > 0:
curr_sliding_window_blocks += 1
curr_sliding_window_block += 1
else:
sliding_seq_len = min(seq_len, self.sliding_window)
return curr_sliding_window_blocks, sliding_seq_len
sliding_seq_len = min(inter_data.seq_lens[seq_idx],
self.sliding_window)

inter_data.curr_sliding_window_blocks[
seq_idx] = curr_sliding_window_block
inter_data.seq_lens[seq_idx] = sliding_seq_len

def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
"""If LoRA is enabled, compute LoRA index and prompt mapping."""
if not self.enable_lora:
return

lora_id = seq_group_metadata.lora_int_id
if lora_id > 0:
inter_data.lora_requests.add(seq_group_metadata.lora_request)
query_len = inter_data.query_lens[seq_idx]
inter_data.lora_index_mapping.append([lora_id] * query_len)
inter_data.lora_prompt_mapping.append(
[lora_id] *
(query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs is not None
else 1))

def _compute_prompt_adapter_input(
self, inter_data: InterDataForSeqGroup,
seq_group_metadata: SequenceGroupMetadata):
"""If prompt adapter is enabled, compute index and prompt mapping.
"""
# Note that when is_prompt=True, we expect only one sequence
# in the group.
if not self.enable_prompt_adapter:
return

prompt_adapter_id = seq_group_metadata.prompt_adapter_id
if prompt_adapter_id <= 0 or not inter_data.is_prompt:
return

# We expect only one sequence in the group when is_prompt=True.
assert inter_data.n_seqs == 1
query_len = inter_data.query_lens[0]
inter_data.prompt_adapter_request = (
seq_group_metadata.prompt_adapter_request)

num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens
inter_data.prompt_adapter_index_mapping = [
prompt_adapter_id
] * num_tokens + [0] * (query_len - num_tokens)
inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * (
query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs else 1)

def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
seq_group_metadata: SequenceGroupMetadata):
"""If multi-modal data is given, add it to the input."""
mm_data = seq_group_metadata.multi_modal_data
if not mm_data:
return

mm_kwargs = self.multi_modal_input_mapper(mm_data)
inter_data.multi_modal_inputs = mm_kwargs

def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
"""Add a sequence group to the builder."""
seq_ids = list(seq_group_metadata.seq_data.keys())
n_seqs = len(seq_ids)
is_prompt = seq_group_metadata.is_prompt
token_chunk_size = seq_group_metadata.token_chunk_size

if is_prompt:
assert n_seqs == 1
self.decode_only = False

# Mapping from request IDs to sequence IDs. Used for Jamba models
# that manages the cache by itself.
self.request_ids_to_seq_ids[seq_group_metadata.request_id] = []
# The number of input tokens in each sequence.
token_lens: List[int] = []
# The number of tokens that are already computed.
context_lens: List[int] = []
# The current sliding window block for each sequence.
curr_sliding_window_blocks: List[int] = []
# The original sequence length (before applying sliding window)
# for each sequence.
orig_seq_lens: List[int] = []
# The sequence length (may be capped to the sliding window).
curr_seq_lens: List[int] = []
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
self.request_ids_to_seq_ids[seq_group_metadata.request_id].append(
seq_id)
computed_block_nums = seq_group_metadata.computed_block_nums
inter_data = self.InterDataForSeqGroup(
request_id=seq_group_metadata.request_id,
seq_ids=seq_ids,
is_prompt=is_prompt,
block_tables=seq_group_metadata.block_tables,
computed_block_nums=seq_group_metadata.computed_block_nums)
self.inter_data_list.append(inter_data)

# Check if hit prefix cache (i.e., some blocks are already computed)
# Note that prefix caching does not support sliding window.
prefix_cache_hit = (computed_block_nums is not None
and len(computed_block_nums) > 0
and self.sliding_window is None and is_prompt)
if self.chunked_prefill_enabled and prefix_cache_hit:
raise RuntimeError(
"chunked prefill cannot be used with prefix caching now.")

# Compute context length (the number of tokens that are
# already computed) and sequence length (total number of tokens).
seq_len = seq_data.get_len()
if is_prompt:
context_len = seq_data.get_num_computed_tokens()
else:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
context_len = seq_len - 1
seq_len = min(seq_len, context_len + token_chunk_size)

# Compute tokens.
if is_prompt:
tokens = seq_data.get_token_ids()[context_len:seq_len]
else:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens = [seq_data.get_last_token_id()]
if prefix_cache_hit:
assert computed_block_nums is not None
context_len = len(computed_block_nums) * self.block_size
tokens = tokens[context_len:]

# These are seq_len/context_len capped to the sliding window.
# They are passed to decode kernel.
# We still need original seq_len/context_len to compute slot
# mapping (and input position) below.
if is_prompt:
curr_sliding_window_block = 0
sliding_seq_len = seq_len
query_len = seq_len - context_len
else:
curr_sliding_window_block, sliding_seq_len = (
self._compute_len_for_sliding_window(seq_len))
query_len = 1

self.seq_lens.append(sliding_seq_len)
if not is_prompt:
self.max_decode_seq_len = max(self.max_decode_seq_len,
sliding_seq_len)
self.query_lens.append(query_len)
self.input_tokens.extend(tokens)
self.input_positions.extend(list(range(context_len, seq_len)))

# Intermediate data of the current sequence group for
# the attention metadata.
token_lens.append(len(tokens))
context_lens.append(context_len)
curr_seq_lens.append(sliding_seq_len)
curr_sliding_window_blocks.append(curr_sliding_window_block)
orig_seq_lens.append(seq_len)

# Update attention metadata. Note that input builder attributes
# (self.xxx) include all added sequences, so we need to slice
# the last n_seqs sequences.
self.attn_metadata_builder.add_seq_group(
seq_group_metadata, token_lens, orig_seq_lens, curr_seq_lens,
self.query_lens[-n_seqs:], context_lens,
curr_sliding_window_blocks, prefix_cache_hit,
self.chunked_prefill_enabled)

# LoRA data.
if self.enable_lora:
lora_id = seq_group_metadata.lora_int_id
for query_len in self.query_lens[-n_seqs:]:
if lora_id > 0:
self.lora_requests.add(seq_group_metadata.lora_request)
self.lora_index_mapping += [lora_id] * query_len
self.lora_prompt_mapping.extend(
[lora_id] *
(query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs
is not None else 1))

# Prompt adapter data. Note that when is_prompt=True,
# we expect only one sequence in the group.
if self.enable_prompt_adapter:
prompt_adapter_id = seq_group_metadata.prompt_adapter_id
if prompt_adapter_id > 0 and is_prompt:
query_len = self.query_lens[-1]
self.prompt_adapter_requests.add(
seq_group_metadata.prompt_adapter_request)

num_tokens = seq_group_metadata.\
prompt_adapter_num_virtual_tokens
pm = [prompt_adapter_id
] * num_tokens + [0] * (query_len - num_tokens)
self.prompt_adapter_index_mapping += pm
self.prompt_adapter_prompt_mapping.extend(
[prompt_adapter_id] *
(query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs
else 1))

# Multi-modal data.
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
mm_kwargs = self.multi_modal_input_mapper(mm_data)
self.multi_modal_inputs_list.append(mm_kwargs)
for seq_idx in range(n_seqs):
for per_seq_fn in self.per_seq_compute_fns:
per_seq_fn(inter_data, seq_idx, seq_group_metadata)
for per_seq_group_fn in self.per_seq_group_compute_fns:
per_seq_group_fn(inter_data, seq_group_metadata)

def build(self) -> ModelInputForGPU:
if not self.input_tokens:
"""Finalize the builder intermediate data and
create on-device tensors.
"""
# Combine and flatten intermediate data.
input_tokens = flatten_2d_lists([
flatten_2d_lists(inter_data.input_tokens)
for inter_data in self.inter_data_list
])
if not input_tokens:
# This may happen when all prefill requests hit
# prefix caching and there is no decode request.
return self.model_input_cls()
input_positions = flatten_2d_lists([
flatten_2d_lists(inter_data.input_positions)
for inter_data in self.inter_data_list
])
seq_lens = []
max_decode_seq_len = 0
for inter_data in self.inter_data_list:
seq_lens.extend(inter_data.seq_lens)
if not inter_data.is_prompt:
max_decode_seq_len = max(max_decode_seq_len,
max(inter_data.seq_lens))
query_lens = flatten_2d_lists(
[inter_data.query_lens for inter_data in self.inter_data_list])
# Mapping from request IDs to sequence IDs. Used for Jamba models
# that manages the cache by itself.
request_ids_to_seq_ids = {
data.request_id: data.seq_ids
for data in self.inter_data_list
}

batch_size = len(self.input_tokens)
batch_size = len(input_tokens)
use_captured_graph = (
self.decode_only and not self.runner.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and self.max_decode_seq_len <= self.runner.max_seq_len_to_capture)
and max_decode_seq_len <= self.runner.max_seq_len_to_capture)

# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
@@ -403,60 +507,84 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
batch_size = graph_batch_size

# Tokens and positions.
self.input_tokens.extend([0] * cuda_graph_pad_size)
self.input_positions.extend([0] * cuda_graph_pad_size)
input_tokens_tensor = torch.tensor(self.input_tokens,
input_tokens.extend([0] * cuda_graph_pad_size)
input_positions.extend([0] * cuda_graph_pad_size)
input_tokens_tensor = torch.tensor(input_tokens,
dtype=torch.long,
device=self.runner.device)
input_positions_tensor = torch.tensor(self.input_positions,
input_positions_tensor = torch.tensor(input_positions,
dtype=torch.long,
device=self.runner.device)

# Sequence and query lengths.
self.seq_lens.extend([1] * cuda_graph_pad_size)
seq_lens.extend([1] * cuda_graph_pad_size)

# Attention metadata.
attn_metadata = self.attn_metadata_builder.build(
self.runner, self.seq_lens, self.query_lens, cuda_graph_pad_size,
batch_size)
seq_lens, query_lens, cuda_graph_pad_size, batch_size)

# LoRA data.
lora_requests = set()
lora_mapping = None
if self.enable_lora:
self.lora_index_mapping.extend([0] * cuda_graph_pad_size)
lora_requests = set(r for data in self.inter_data_list
for r in data.lora_requests)
lora_index_mapping = flatten_2d_lists([
flatten_2d_lists(inter_data.lora_index_mapping)
for inter_data in self.inter_data_list
])
lora_index_mapping.extend([0] * cuda_graph_pad_size)
lora_prompt_mapping = flatten_2d_lists([
flatten_2d_lists(inter_data.lora_prompt_mapping)
for inter_data in self.inter_data_list
])
lora_mapping = LoRAMapping(
self.lora_index_mapping,
self.lora_prompt_mapping,
lora_index_mapping,
lora_prompt_mapping,
)
else:
lora_mapping = None

# Prompt adapter data.
prompt_adapter_requests: Set[PromptAdapterRequest] = set()
prompt_adapter_mapping = None
if self.enable_prompt_adapter:
self.prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size)
prompt_adapter_requests = set(
data.prompt_adapter_request for data in self.inter_data_list
if data.prompt_adapter_request is not None)
prompt_adapter_index_mapping = flatten_2d_lists([
inter_data.prompt_adapter_index_mapping
for inter_data in self.inter_data_list
])
prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size)
prompt_adapter_prompt_mapping = flatten_2d_lists([
inter_data.prompt_adapter_prompt_mapping
for inter_data in self.inter_data_list
])
prompt_adapter_mapping = PromptAdapterMapping(
self.prompt_adapter_index_mapping,
self.prompt_adapter_prompt_mapping,
prompt_adapter_index_mapping,
prompt_adapter_prompt_mapping,
)
else:
prompt_adapter_mapping = None

# Multi-modal data.
multi_modal_kwargs = MultiModalInputs.batch(
self.multi_modal_inputs_list, device=self.runner.device)
multi_modal_inputs_list = [
data.multi_modal_inputs for data in self.inter_data_list
if data.multi_modal_inputs is not None
]
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.runner.device)

return self.model_input_cls(
input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor,
attn_metadata=attn_metadata,
seq_lens=self.seq_lens,
query_lens=self.query_lens,
seq_lens=seq_lens,
query_lens=query_lens,
lora_mapping=lora_mapping,
lora_requests=self.lora_requests,
lora_requests=lora_requests,
multi_modal_kwargs=multi_modal_kwargs,
request_ids_to_seq_ids=self.request_ids_to_seq_ids,
request_ids_to_seq_ids=request_ids_to_seq_ids,
finished_requests_ids=self.finished_requests_ids,
prompt_adapter_mapping=prompt_adapter_mapping,
prompt_adapter_requests=self.prompt_adapter_requests)
prompt_adapter_requests=prompt_adapter_requests)


class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
@@ -1393,15 +1521,3 @@ def _get_graph_batch_size(batch_size: int) -> int:
else:
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


def _is_block_tables_empty(block_tables: Union[None, Dict]):
"""
Check if block_tables is None or a dictionary with all None values.
"""
if block_tables is None:
return True
if isinstance(block_tables, dict) and all(
value is None for value in block_tables.values()):
return True
return False

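To close, a hedged sketch of how the refactored builder is driven end to end; the caller inside GPUModelRunnerBase is not shown in this diff, so the wrapper function and its name below are illustrative assumptions:

```python
# Hedged end-to-end sketch. The driver inside GPUModelRunnerBase is not part
# of this diff, so treat the exact call site as an assumption; the builder
# methods and signatures used here are the ones introduced above.
from typing import List

from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder


def prepare_model_input_sketch(
        runner,  # a GPUModelRunnerBase instance
        seq_group_metadata_list: List[SequenceGroupMetadata],
) -> ModelInputForGPU:
    builder = ModelInputForGPUBuilder(runner)
    for seq_group_metadata in seq_group_metadata_list:
        # Per-sequence and per-sequence-group compute functions populate
        # builder.inter_data_list (InterDataForSeqGroup records).
        builder.add_seq_group(seq_group_metadata)
    # build() flattens the intermediate data, pads for CUDA graph capture if
    # applicable, and asks the attention backend's metadata builder for
    # on-device attention metadata.
    return builder.build()
```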