diff --git a/cacheflow/models/__init__.py b/cacheflow/models/__init__.py
index f0c68e5b177e7..498101b53fdd7 100644
--- a/cacheflow/models/__init__.py
+++ b/cacheflow/models/__init__.py
@@ -1,7 +1,8 @@
-from cacheflow.worker.models.model_utils import get_model
+from cacheflow.models.input_metadata import InputMetadata
+from cacheflow.models.model_utils import get_model
 
 
 __all__ = [
     'get_model',
-
+    'InputMetadata',
 ]
diff --git a/cacheflow/models/input_metadata.py b/cacheflow/models/input_metadata.py
new file mode 100644
index 0000000000000..253b4389dd5aa
--- /dev/null
+++ b/cacheflow/models/input_metadata.py
@@ -0,0 +1,25 @@
+from typing import List
+
+import torch
+
+
+class InputMetadata:
+
+    def __init__(
+        self,
+        prompt_lens: List[int],
+        slot_mapping: torch.Tensor,
+        context_lens: torch.Tensor,
+        max_context_len: int,
+        block_tables: torch.Tensor,
+    ) -> None:
+        self.prompt_lens = prompt_lens
+        self.slot_mapping = slot_mapping
+        self.context_lens = context_lens
+        self.max_context_len = max_context_len
+        self.block_tables = block_tables
+
+        self.num_prompts = len(prompt_lens)
+        self.num_generation_tokens = context_lens.shape[0]
+        self.max_num_blocks_per_seq = block_tables.shape[1]
+        assert self.num_generation_tokens == block_tables.shape[0]
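
For context (not part of the diff): a minimal usage sketch of the new InputMetadata class. The batch composition, slot indices, and block-table values below are made-up assumptions for illustration; only the constructor signature and the derived fields come from the diff above.

import torch

from cacheflow.models import InputMetadata

# Hypothetical batch: 2 prompt sequences plus 3 sequences in the generation phase.
prompt_lens = [5, 7]                      # number of tokens in each prompt
slot_mapping = torch.tensor(              # illustrative cache slot per prompt token
    [0, 1, 2, 3, 4, 16, 17, 18, 19, 20, 21, 22], dtype=torch.int)
context_lens = torch.tensor([9, 12, 30], dtype=torch.int)  # one entry per generating sequence
block_tables = torch.tensor(              # one row per generating sequence, fixed width
    [[0, 1, 0],
     [2, 3, 0],
     [4, 5, 6]], dtype=torch.int)

metadata = InputMetadata(
    prompt_lens=prompt_lens,
    slot_mapping=slot_mapping,
    context_lens=context_lens,
    max_context_len=int(context_lens.max()),
    block_tables=block_tables,
)

# Fields derived in __init__:
assert metadata.num_prompts == 2              # len(prompt_lens)
assert metadata.num_generation_tokens == 3    # context_lens.shape[0]
assert metadata.max_num_blocks_per_seq == 3   # block_tables.shape[1]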