diff --git a/cacheflow/block.py b/cacheflow/block.py
index 01edf6177755..47e2cfc683c2 100644
--- a/cacheflow/block.py
+++ b/cacheflow/block.py
@@ -1,11 +1,17 @@
+"""Token blocks."""
 from typing import List
 
 from cacheflow.utils import Device
 
-BLANK_TOKEN_ID = -1
+_BLANK_TOKEN_ID = -1
 
 
 class LogicalTokenBlock:
+    """A block that stores a contiguous chunk of tokens from left to right.
+
+    Logical blocks are used to represent the states of the corresponding
+    physical blocks in the KV cache.
+    """
 
     def __init__(
         self,
@@ -15,7 +21,7 @@ class LogicalTokenBlock:
         self.block_number = block_number
         self.block_size = block_size
 
-        self.token_ids = [BLANK_TOKEN_ID] * block_size
+        self.token_ids = [_BLANK_TOKEN_ID] * block_size
         self.num_tokens = 0
 
     def is_empty(self) -> bool:
@@ -41,6 +47,7 @@ class LogicalTokenBlock:
 
 
 class PhysicalTokenBlock:
+    """Represents the state of a block in the KV cache."""
 
     def __init__(
         self,
diff --git a/cacheflow/core/block_manager.py b/cacheflow/core/block_manager.py
index 9e64e1d6bfb4..2d471e0b024f 100644
--- a/cacheflow/core/block_manager.py
+++ b/cacheflow/core/block_manager.py
@@ -1,13 +1,18 @@
+"""A block manager that manages token blocks."""
 from typing import Dict, List, Optional, Set, Tuple
 
 from cacheflow.block import PhysicalTokenBlock
-from cacheflow.sequence import Sequence
-from cacheflow.sequence import SequenceGroup
-from cacheflow.sequence import SequenceStatus
+from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Device
 
 
 class BlockAllocator:
+    """Manages free physical token blocks for a device.
+
+    The allocator maintains a list of free blocks and allocates a block when
+    requested. When a block is freed, its reference count is decremented. If
+    the reference count becomes zero, the block is added back to the free list.
+    """
 
     def __init__(
         self,
@@ -20,24 +25,22 @@ class BlockAllocator:
         self.num_blocks = num_blocks
 
         # Initialize the free blocks.
-        # TODO(woosuk): Make this a priority queue.
-        self.free_blocks = [
-            PhysicalTokenBlock(device=device, block_number=i, block_size=block_size)
-            for i in range(num_blocks)
-        ]
+        self.free_blocks: List[PhysicalTokenBlock] = []
+        for i in range(num_blocks):
+            block = PhysicalTokenBlock(
+                device=device, block_number=i, block_size=block_size)
+            self.free_blocks.append(block)
 
     def allocate(self) -> PhysicalTokenBlock:
         if not self.free_blocks:
-            raise ValueError('Out of memory! '
-                             f'No more free blocks are available.')
+            raise ValueError("Out of memory! No free blocks are available.")
         block = self.free_blocks.pop()
         block.ref_count = 1
         return block
 
     def free(self, block: PhysicalTokenBlock) -> None:
         if block.ref_count == 0:
-            raise ValueError('Double free! '
-                             f'The block {block} is already freed.')
+            raise ValueError(f"Double free! {block} is already freed.")
         block.ref_count -= 1
         if block.ref_count == 0:
             self.free_blocks.append(block)
@@ -51,6 +54,7 @@ BlockTable = List[PhysicalTokenBlock]
 
 
 class BlockSpaceManager:
+    """Manages the mapping between logical and physical token blocks."""
 
     def __init__(
         self,
@@ -66,9 +70,10 @@ class BlockSpaceManager:
         assert watermark >= 0.0
         self.watermark_blocks = int(watermark * num_gpu_blocks)
 
-        self.gpu_allocator = BlockAllocator(Device.GPU, block_size, num_gpu_blocks)
-        self.cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks)
-
+        self.gpu_allocator = BlockAllocator(Device.GPU, block_size,
+                                            num_gpu_blocks)
+        self.cpu_allocator = BlockAllocator(Device.CPU, block_size,
+                                            num_cpu_blocks)
         # Mapping: seq_id -> BlockTable.
         self.block_tables: Dict[int, BlockTable] = {}
 
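Note: the BlockAllocator docstring above describes a free list with reference counting. As a hedged usage sketch (argument order follows the positional calls visible in BlockSpaceManager; the sizes are arbitrary and the sharing step is only conceptual), the allocate/free cycle looks like this:

from cacheflow.core.block_manager import BlockAllocator
from cacheflow.utils import Device

# A small CPU-side allocator: device, block_size, num_blocks.
allocator = BlockAllocator(Device.CPU, 16, 4)

block = allocator.allocate()   # pops a free block, ref_count == 1
block.ref_count += 1           # a second sequence shares the same block

allocator.free(block)          # ref_count -> 1, block stays in use
allocator.free(block)          # ref_count -> 0, block returns to the free list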
diff --git a/cacheflow/core/server.py b/cacheflow/core/server.py
index a35a27bcb129..46b75d8117f7 100644
--- a/cacheflow/core/server.py
+++ b/cacheflow/core/server.py
@@ -1,12 +1,12 @@
 import argparse
-from typing import List, Tuple, Optional
 import random
+from typing import List, Optional, Tuple
 
-import torch
 try:
     import ray
 except ImportError:
     ray = None
+import torch
 
 from cacheflow.core.scheduler import Scheduler
 from cacheflow.frontend.simple_frontend import SimpleFrontend
diff --git a/cacheflow/model_executor/layers/activation.py b/cacheflow/model_executor/layers/activation.py
index c3267ebcb7d4..f82d57769fa8 100644
--- a/cacheflow/model_executor/layers/activation.py
+++ b/cacheflow/model_executor/layers/activation.py
@@ -1,3 +1,4 @@
+"""Custom activation functions."""
 import torch
 import torch.nn as nn
 
@@ -5,6 +6,10 @@ from cacheflow import activation_ops
 
 
 class SiluAndMul(nn.Module):
+    """An activation function for SwiGLU.
+
+    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
+    """
 
     def __init__(self):
         super().__init__()
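Note: the SiluAndMul docstring pins down the math; for reference, an unfused plain-PyTorch equivalent (hypothetical helper name, assuming a 2D input of shape [num_tokens, 2 * d]) would be:

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Split the projection into two halves and gate one with SiLU.
    d = x.shape[1] // 2
    return F.silu(x[:, :d]) * x[:, d:]

x = torch.randn(8, 2 * 4096)
out = silu_and_mul_ref(x)  # shape: [8, 4096]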
+ """ def __init__(self, scale: float) -> None: super().__init__() @@ -157,7 +184,7 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention): torch_dtype = torch.get_default_dtype() cache = cache.to(torch_dtype) # Embedding size: [max_position, rotary_dim] - self.register_buffer('cos_sin_cache', cache, persistent=False) + self.register_buffer("cos_sin_cache", cache, persistent=False) def forward( self, diff --git a/cacheflow/model_executor/layers/layernorm.py b/cacheflow/model_executor/layers/layernorm.py index 37c41c5fce42..f6e52cced781 100644 --- a/cacheflow/model_executor/layers/layernorm.py +++ b/cacheflow/model_executor/layers/layernorm.py @@ -1,3 +1,4 @@ +"""Custom normalization layers.""" import torch import torch.nn as nn @@ -5,6 +6,11 @@ from cacheflow import layernorm_ops class RMSNorm(nn.Module): + """Root mean square normalization. + + Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. + Refer to https://arxiv.org/abs/1910.07467 + """ def __init__( self, diff --git a/cacheflow/model_executor/layers/sampler.py b/cacheflow/model_executor/layers/sampler.py index 05703bce4cc8..7321dbf35e4d 100644 --- a/cacheflow/model_executor/layers/sampler.py +++ b/cacheflow/model_executor/layers/sampler.py @@ -1,3 +1,4 @@ +"""A layer that samples the next tokens from the model's outputs.""" from typing import Dict, List, Tuple import numpy as np @@ -12,6 +13,19 @@ from cacheflow.sequence import SequenceOutputs class Sampler(nn.Module): + """Samples the next tokens from the model's outputs. + + This layer does the following: + 1. Discard the hidden states that are not used for sampling (i.e., all + tokens except the final one in each prompt). + 2. Compute the logits for the next tokens. + 3. Apply presence and frequency penalties. + 4. Apply temperature scaling. + 5. Apply top-p and top-k truncation. + 6. Sample the next tokens. + Here, each sequence group within the batch can have different sampling + parameters (e.g., sampling method, temperature, top-p, top-k, etc.). + """ def __init__(self, vocab_size: int) -> None: super().__init__() diff --git a/cacheflow/model_executor/model_loader.py b/cacheflow/model_executor/model_loader.py index 5598309e8a4c..a89fe7584f3b 100644 --- a/cacheflow/model_executor/model_loader.py +++ b/cacheflow/model_executor/model_loader.py @@ -1,9 +1,9 @@ +"""Utilities for selecting and loading models.""" from typing import Optional import torch import torch.nn as nn -from transformers import AutoConfig -from transformers import PretrainedConfig +from transformers import AutoConfig, PretrainedConfig from cacheflow.model_executor.memory_analyzer import ( CacheFlowMemoryAnalyzer, GPT2MemoryAnalyzer, GPTNeoXMemoryAnalyzer, diff --git a/cacheflow/model_executor/models/gpt2.py b/cacheflow/model_executor/models/gpt2.py index 97673ca450f6..4810f38562d8 100644 --- a/cacheflow/model_executor/models/gpt2.py +++ b/cacheflow/model_executor/models/gpt2.py @@ -15,7 +15,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""1D GPT-2 model compatible with HuggingFace weights.""" +"""Inference-only GPT-2 model compatible with HuggingFace weights. + +The input of the model is flattened to a 1D tensor of tokens. The model uses +InputMetadata to extract the original 2D shape of the input. 
+""" from typing import Dict, List, Optional, Tuple import torch diff --git a/cacheflow/model_executor/models/gpt_neox.py b/cacheflow/model_executor/models/gpt_neox.py index 8b074aecca3d..916ba3f0ea27 100644 --- a/cacheflow/model_executor/models/gpt_neox.py +++ b/cacheflow/model_executor/models/gpt_neox.py @@ -14,7 +14,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""1D GPT-NeoX model compatible with HuggingFace weights.""" +"""Inference-only GPT-NeoX model compatible with HuggingFace weights. + +The input of the model is flattened to a 1D tensor of tokens. The model uses +InputMetadata to extract the original 2D shape of the input. +""" from typing import Dict, List, Optional, Tuple import torch @@ -79,6 +83,7 @@ class GPTNeoXAttention(nn.Module): class GPTNeoXMLP(nn.Module): + def __init__(self, config: GPTNeoXConfig): super().__init__() self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size, diff --git a/cacheflow/model_executor/models/llama.py b/cacheflow/model_executor/models/llama.py index f42210c7e035..04699cad4db9 100644 --- a/cacheflow/model_executor/models/llama.py +++ b/cacheflow/model_executor/models/llama.py @@ -19,7 +19,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""1D LLaMA model compatible with HuggingFace weights.""" +"""Inference-only LLaMA model compatible with HuggingFace weights. + +The input of the model is flattened to a 1D tensor of tokens. The model uses +InputMetadata to extract the original 2D shape of the input. +""" from typing import Dict, List, Optional, Tuple import torch diff --git a/cacheflow/model_executor/models/opt.py b/cacheflow/model_executor/models/opt.py index ddb09d42e7bc..e51abe84234e 100644 --- a/cacheflow/model_executor/models/opt.py +++ b/cacheflow/model_executor/models/opt.py @@ -14,7 +14,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""1D OPT model compatible with HuggingFace weights.""" +"""Inference-only OPT model compatible with HuggingFace weights. + +The input of the model is flattened to a 1D tensor of tokens. The model uses +InputMetadata to extract the original 2D shape of the input. 
+""" from typing import Dict, List, Optional, Tuple import torch diff --git a/cacheflow/model_executor/utils.py b/cacheflow/model_executor/utils.py index 72f76a25c87e..ae7810ca5c95 100644 --- a/cacheflow/model_executor/utils.py +++ b/cacheflow/model_executor/utils.py @@ -1,3 +1,4 @@ +"""Utils for model executor.""" import random from typing import Union @@ -9,11 +10,11 @@ from cacheflow.model_executor.parallel_utils.tensor_parallel import model_parall _STR_DTYPE_TO_TORCH_DTYPE = { - 'half': torch.half, - 'float': torch.float, - 'float16': torch.float16, - 'float32': torch.float32, - 'bfloat16': torch.bfloat16, + "half": torch.half, + "float": torch.float, + "float16": torch.float16, + "float32": torch.float32, + "bfloat16": torch.bfloat16, } diff --git a/cacheflow/model_executor/weight_utils.py b/cacheflow/model_executor/weight_utils.py index 796bfffbb70f..cef67fde587f 100644 --- a/cacheflow/model_executor/weight_utils.py +++ b/cacheflow/model_executor/weight_utils.py @@ -1,3 +1,4 @@ +"""Utilities for downloading and initializing model weights.""" import filelock import glob import json @@ -106,5 +107,12 @@ def initialize_dummy_weights( low: float = -1e-3, high: float = 1e-3, ) -> None: + """Initialize model weights with random values. + + The model weights must be randomly initialized for accurate performance + measurements. Additionally, the model weights should not cause NaNs in the + forward pass. We empirically found that initializing the weights with + values between -1e-3 and 1e-3 works well for most models. + """ for param in model.state_dict().values(): param.data.uniform_(low, high) diff --git a/cacheflow/sampling_params.py b/cacheflow/sampling_params.py index 0a2fa0196daf..d140521c919b 100644 --- a/cacheflow/sampling_params.py +++ b/cacheflow/sampling_params.py @@ -1,7 +1,37 @@ +"""Sampling parameters for text generation.""" from typing import Set class SamplingParams: + """Sampling parameters for text generation. + + Overall, we follow the sampling parameters from the OpenAI text completion + API (https://platform.openai.com/docs/api-reference/completions/create). + In addition, we support beam search, which is not supported by OpenAI. + + Args: + n: Number of output sequences to generate from the given prompt. This is + regarded as the beam width when using beam search. + presence_penalty: Float that penalizes new tokens based on whether they + appear in the generated text so far. Values > 0 encourage the model + to use new tokens, while values < 0 encourage the model to repeat + tokens. + frequency_penalty: Float that penalizes new tokens based on their + frequency in the generated text so far. Values > 0 encourage the + model to use new tokens, while values < 0 encourage the model to + repeat tokens. + temperature: Float that controls the randomness of the sampling. Lower + values make the model more deterministic, while higher values make + the model more random. Zero means greedy sampling. + top_p: Float that controls the cumulative probability of the top tokens + to consider. Must be in (0, 1]. Set to 1 to consider all tokens. + top_k: Integer that controls the number of top tokens to consider. Set + to -1 to consider all tokens. + use_beam_search: Whether to use beam search instead of sampling. + stop_token_ids: Set of token IDs that indicate the end of a sequence. + max_tokens: Maximum number of tokens to generate per output sequence. + logprobs: Number of log probabilities to return per output token. 
+ """ def __init__( self, diff --git a/cacheflow/worker/cache_engine.py b/cacheflow/worker/cache_engine.py index addde3883b69..c044a06e6def 100644 --- a/cacheflow/worker/cache_engine.py +++ b/cacheflow/worker/cache_engine.py @@ -1,12 +1,20 @@ +"""CacheEngine class for managing the KV cache.""" from typing import Dict, List, Tuple import torch + from cacheflow import cache_ops KVCache = Tuple[torch.Tensor, torch.Tensor] class CacheEngine: + """Manages the KV cache. + + This class is responsible for initializing and managing the GPU and CPU KV + caches. It also provides methods for performing KV cache operations, such + as swapping and copying. + """ def __init__( self, diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py index 90d5d7af97ac..c98cf2257255 100644 --- a/cacheflow/worker/worker.py +++ b/cacheflow/worker/worker.py @@ -1,3 +1,4 @@ +"""A GPU worker class.""" from typing import Dict, List, Optional, Tuple import torch @@ -14,6 +15,12 @@ from cacheflow.worker.cache_engine import CacheEngine class Worker: + """A worker class that executes (a partition of) the model on a GPU. + + Each worker is associated with a single GPU. The worker is responsible for + maintaining the KV cache and executing the model on the GPU. In case of + distributed inference, each worker is assigned a partition of the model. + """ def __init__( self,