Add docstrings to some modules and classes (#100)

This commit is contained in:
Woosuk Kwon 2023-05-14 22:32:38 -07:00 committed by GitHub
parent 667ba3995c
commit b322fd1607
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 166 additions and 31 deletions

View File

@ -1,11 +1,17 @@
"""Token blocks."""
from typing import List
from cacheflow.utils import Device
BLANK_TOKEN_ID = -1
_BLANK_TOKEN_ID = -1
class LogicalTokenBlock:
"""A block that stores a contiguous chunk of tokens from left to right.
Logical blocks are used to represent the states of the corresponding
physical blocks in the KV cache.
"""
def __init__(
self,
@ -15,7 +21,7 @@ class LogicalTokenBlock:
self.block_number = block_number
self.block_size = block_size
self.token_ids = [BLANK_TOKEN_ID] * block_size
self.token_ids = [_BLANK_TOKEN_ID] * block_size
self.num_tokens = 0
def is_empty(self) -> bool:
@ -41,6 +47,7 @@ class LogicalTokenBlock:
class PhysicalTokenBlock:
"""Represents the state of a block in the KV cache."""
def __init__(
self,

View File

@ -1,13 +1,18 @@
"""A block manager that manages token blocks."""
from typing import Dict, List, Optional, Set, Tuple
from cacheflow.block import PhysicalTokenBlock
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import SequenceStatus
from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
from cacheflow.utils import Device
class BlockAllocator:
"""Manages free physical token blocks for a device.
The allocator maintains a list of free blocks and allocates a block when
requested. When a block is freed, its reference count is decremented. If
the reference count becomes zero, the block is added back to the free list.
"""
def __init__(
self,
@ -20,24 +25,22 @@ class BlockAllocator:
self.num_blocks = num_blocks
# Initialize the free blocks.
# TODO(woosuk): Make this a priority queue.
self.free_blocks = [
PhysicalTokenBlock(device=device, block_number=i, block_size=block_size)
for i in range(num_blocks)
]
self.free_blocks: List[PhysicalTokenBlock] = []
for i in range(num_blocks):
block = PhysicalTokenBlock(
device=device, block_number=i, block_size=block_size)
self.free_blocks.append(block)
def allocate(self) -> PhysicalTokenBlock:
if not self.free_blocks:
raise ValueError('Out of memory! '
f'No more free blocks are available.')
raise ValueError("Out of memory! No free blocks are available.")
block = self.free_blocks.pop()
block.ref_count = 1
return block
def free(self, block: PhysicalTokenBlock) -> None:
if block.ref_count == 0:
raise ValueError('Double free! '
f'The block {block} is already freed.')
raise ValueError(f"Double free! {block} is already freed.")
block.ref_count -= 1
if block.ref_count == 0:
self.free_blocks.append(block)
@ -51,6 +54,7 @@ BlockTable = List[PhysicalTokenBlock]
class BlockSpaceManager:
"""Manages the mapping between logical and physical token blocks."""
def __init__(
self,
@ -66,9 +70,10 @@ class BlockSpaceManager:
assert watermark >= 0.0
self.watermark_blocks = int(watermark * num_gpu_blocks)
self.gpu_allocator = BlockAllocator(Device.GPU, block_size, num_gpu_blocks)
self.cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks)
self.gpu_allocator = BlockAllocator(Device.GPU, block_size,
num_gpu_blocks)
self.cpu_allocator = BlockAllocator(Device.CPU, block_size,
num_cpu_blocks)
# Mapping: seq_id -> BlockTable.
self.block_tables: Dict[int, BlockTable] = {}

View File

@ -1,12 +1,12 @@
import argparse
from typing import List, Tuple, Optional
import random
from typing import List, Optional, Tuple
import torch
try:
import ray
except ImportError:
ray = None
import torch
from cacheflow.core.scheduler import Scheduler
from cacheflow.frontend.simple_frontend import SimpleFrontend

View File

@ -1,3 +1,4 @@
"""Custom activation functions."""
import torch
import torch.nn as nn
@ -5,6 +6,10 @@ from cacheflow import activation_ops
class SiluAndMul(nn.Module):
"""An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
"""
def __init__(self):
super().__init__()

View File

@ -1,3 +1,4 @@
"""Multi-head attention."""
from typing import Optional
import torch
@ -11,6 +12,32 @@ from cacheflow.model_executor.input_metadata import InputMetadata
class GPTCacheFlowAttention(nn.Module):
"""GPT-style multi-head attention.
This class takes flattened 1D query, key, and value tensors as input. The
input 1D tensors can be split into three parts: the prompt tokens, the
generation tokens, and the paddings.
|<------------------------------------- num_valid_tokens ------------------------------------->|
|<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
The prompts might have different lengths, while the generation tokens always
have length 1. The paddings are appended to make the input length a multiple
of 8, which is desirable for Tensor Cores.
The class does the following:
1. Perform multi_query_kv_attention for the prompts. This operation does
not use the KV cache.
2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
operations are issued by the cache engine before executing the forward
pass of the model, and they are executed asynchronously.
3. Reshape and store the input key and value tensors in the KV cache.
4. Perform single_query_cached_kv_attention for the generation tokens.
This operation reads the previous key and value tensors from the KV
cache.
5. Output a flattened 1D tensor.
"""
def __init__(self, scale: float) -> None:
super().__init__()
@ -157,7 +184,7 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
torch_dtype = torch.get_default_dtype()
cache = cache.to(torch_dtype)
# Embedding size: [max_position, rotary_dim]
self.register_buffer('cos_sin_cache', cache, persistent=False)
self.register_buffer("cos_sin_cache", cache, persistent=False)
def forward(
self,

View File

@ -1,3 +1,4 @@
"""Custom normalization layers."""
import torch
import torch.nn as nn
@ -5,6 +6,11 @@ from cacheflow import layernorm_ops
class RMSNorm(nn.Module):
"""Root mean square normalization.
Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
Refer to https://arxiv.org/abs/1910.07467
"""
def __init__(
self,

View File

@ -1,3 +1,4 @@
"""A layer that samples the next tokens from the model's outputs."""
from typing import Dict, List, Tuple
import numpy as np
@ -12,6 +13,19 @@ from cacheflow.sequence import SequenceOutputs
class Sampler(nn.Module):
"""Samples the next tokens from the model's outputs.
This layer does the following:
1. Discard the hidden states that are not used for sampling (i.e., all
tokens except the final one in each prompt).
2. Compute the logits for the next tokens.
3. Apply presence and frequency penalties.
4. Apply temperature scaling.
5. Apply top-p and top-k truncation.
6. Sample the next tokens.
Here, each sequence group within the batch can have different sampling
parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
"""
def __init__(self, vocab_size: int) -> None:
super().__init__()

View File

@ -1,9 +1,9 @@
"""Utilities for selecting and loading models."""
from typing import Optional
import torch
import torch.nn as nn
from transformers import AutoConfig
from transformers import PretrainedConfig
from transformers import AutoConfig, PretrainedConfig
from cacheflow.model_executor.memory_analyzer import (
CacheFlowMemoryAnalyzer, GPT2MemoryAnalyzer, GPTNeoXMemoryAnalyzer,

View File

@ -15,7 +15,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""1D GPT-2 model compatible with HuggingFace weights."""
"""Inference-only GPT-2 model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple
import torch

View File

@ -14,7 +14,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""1D GPT-NeoX model compatible with HuggingFace weights."""
"""Inference-only GPT-NeoX model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple
import torch
@ -79,6 +83,7 @@ class GPTNeoXAttention(nn.Module):
class GPTNeoXMLP(nn.Module):
def __init__(self, config: GPTNeoXConfig):
super().__init__()
self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,

View File

@ -19,7 +19,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""1D LLaMA model compatible with HuggingFace weights."""
"""Inference-only LLaMA model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple
import torch

View File

@ -14,7 +14,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""1D OPT model compatible with HuggingFace weights."""
"""Inference-only OPT model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple
import torch

View File

@ -1,3 +1,4 @@
"""Utils for model executor."""
import random
from typing import Union
@ -9,11 +10,11 @@ from cacheflow.model_executor.parallel_utils.tensor_parallel import model_parall
_STR_DTYPE_TO_TORCH_DTYPE = {
'half': torch.half,
'float': torch.float,
'float16': torch.float16,
'float32': torch.float32,
'bfloat16': torch.bfloat16,
"half": torch.half,
"float": torch.float,
"float16": torch.float16,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}

View File

@ -1,3 +1,4 @@
"""Utilities for downloading and initializing model weights."""
import filelock
import glob
import json
@ -106,5 +107,12 @@ def initialize_dummy_weights(
low: float = -1e-3,
high: float = 1e-3,
) -> None:
"""Initialize model weights with random values.
The model weights must be randomly initialized for accurate performance
measurements. Additionally, the model weights should not cause NaNs in the
forward pass. We empirically found that initializing the weights with
values between -1e-3 and 1e-3 works well for most models.
"""
for param in model.state_dict().values():
param.data.uniform_(low, high)

View File

@ -1,7 +1,37 @@
"""Sampling parameters for text generation."""
from typing import Set
class SamplingParams:
"""Sampling parameters for text generation.
Overall, we follow the sampling parameters from the OpenAI text completion
API (https://platform.openai.com/docs/api-reference/completions/create).
In addition, we support beam search, which is not supported by OpenAI.
Args:
n: Number of output sequences to generate from the given prompt. This is
regarded as the beam width when using beam search.
presence_penalty: Float that penalizes new tokens based on whether they
appear in the generated text so far. Values > 0 encourage the model
to use new tokens, while values < 0 encourage the model to repeat
tokens.
frequency_penalty: Float that penalizes new tokens based on their
frequency in the generated text so far. Values > 0 encourage the
model to use new tokens, while values < 0 encourage the model to
repeat tokens.
temperature: Float that controls the randomness of the sampling. Lower
values make the model more deterministic, while higher values make
the model more random. Zero means greedy sampling.
top_p: Float that controls the cumulative probability of the top tokens
to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
top_k: Integer that controls the number of top tokens to consider. Set
to -1 to consider all tokens.
use_beam_search: Whether to use beam search instead of sampling.
stop_token_ids: Set of token IDs that indicate the end of a sequence.
max_tokens: Maximum number of tokens to generate per output sequence.
logprobs: Number of log probabilities to return per output token.
"""
def __init__(
self,

View File

@ -1,12 +1,20 @@
"""CacheEngine class for managing the KV cache."""
from typing import Dict, List, Tuple
import torch
from cacheflow import cache_ops
KVCache = Tuple[torch.Tensor, torch.Tensor]
class CacheEngine:
"""Manages the KV cache.
This class is responsible for initializing and managing the GPU and CPU KV
caches. It also provides methods for performing KV cache operations, such
as swapping and copying.
"""
def __init__(
self,

View File

@ -1,3 +1,4 @@
"""A GPU worker class."""
from typing import Dict, List, Optional, Tuple
import torch
@ -14,6 +15,12 @@ from cacheflow.worker.cache_engine import CacheEngine
class Worker:
"""A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single GPU. The worker is responsible for
maintaining the KV cache and executing the model on the GPU. In case of
distributed inference, each worker is assigned a partition of the model.
"""
def __init__(
self,