mirror of https://git.datalinker.icu/vllm-project/vllm.git

Add docstrings to some modules and classes (#100)

parent 667ba3995c
commit b322fd1607

@@ -1,11 +1,17 @@
+"""Token blocks."""
 from typing import List

 from cacheflow.utils import Device


-BLANK_TOKEN_ID = -1
+_BLANK_TOKEN_ID = -1


 class LogicalTokenBlock:
+    """A block that stores a contiguous chunk of tokens from left to right.
+
+    Logical blocks are used to represent the states of the corresponding
+    physical blocks in the KV cache.
+    """

     def __init__(
         self,
@@ -15,7 +21,7 @@ class LogicalTokenBlock:
         self.block_number = block_number
         self.block_size = block_size

-        self.token_ids = [BLANK_TOKEN_ID] * block_size
+        self.token_ids = [_BLANK_TOKEN_ID] * block_size
         self.num_tokens = 0

     def is_empty(self) -> bool:
@@ -41,6 +47,7 @@ class LogicalTokenBlock:


 class PhysicalTokenBlock:
+    """Represents the state of a block in the KV cache."""

     def __init__(
         self,
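
A self-contained sketch of the left-to-right filling behavior that the LogicalTokenBlock docstring above describes. The class and its append_tokens helper are illustrative stand-ins, not the code from this commit:

from typing import List

_BLANK_TOKEN_ID = -1  # placeholder value for slots that do not hold a token yet


class ToyLogicalBlock:
    """Fixed-size block that is filled with token IDs from left to right."""

    def __init__(self, block_number: int, block_size: int) -> None:
        self.block_number = block_number
        self.block_size = block_size
        self.token_ids = [_BLANK_TOKEN_ID] * block_size
        self.num_tokens = 0

    def is_empty(self) -> bool:
        return self.num_tokens == 0

    def is_full(self) -> bool:
        return self.num_tokens == self.block_size

    def append_tokens(self, token_ids: List[int]) -> None:
        # Only as many tokens as there are blank slots may be appended.
        assert len(token_ids) <= self.block_size - self.num_tokens
        start = self.num_tokens
        self.token_ids[start:start + len(token_ids)] = token_ids
        self.num_tokens += len(token_ids)


block = ToyLogicalBlock(block_number=0, block_size=4)
block.append_tokens([101, 102])
print(block.is_full(), block.token_ids)  # False [101, 102, -1, -1]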

@@ -1,13 +1,18 @@
+"""A block manager that manages token blocks."""
 from typing import Dict, List, Optional, Set, Tuple

 from cacheflow.block import PhysicalTokenBlock
-from cacheflow.sequence import Sequence
-from cacheflow.sequence import SequenceGroup
-from cacheflow.sequence import SequenceStatus
+from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Device


 class BlockAllocator:
+    """Manages free physical token blocks for a device.
+
+    The allocator maintains a list of free blocks and allocates a block when
+    requested. When a block is freed, its reference count is decremented. If
+    the reference count becomes zero, the block is added back to the free list.
+    """

     def __init__(
         self,
@@ -20,24 +25,22 @@ class BlockAllocator:
         self.num_blocks = num_blocks

         # Initialize the free blocks.
-        # TODO(woosuk): Make this a priority queue.
-        self.free_blocks = [
-            PhysicalTokenBlock(device=device, block_number=i, block_size=block_size)
-            for i in range(num_blocks)
-        ]
+        self.free_blocks: List[PhysicalTokenBlock] = []
+        for i in range(num_blocks):
+            block = PhysicalTokenBlock(
+                device=device, block_number=i, block_size=block_size)
+            self.free_blocks.append(block)

     def allocate(self) -> PhysicalTokenBlock:
         if not self.free_blocks:
-            raise ValueError('Out of memory! '
-                             f'No more free blocks are available.')
+            raise ValueError("Out of memory! No free blocks are available.")
         block = self.free_blocks.pop()
         block.ref_count = 1
         return block

     def free(self, block: PhysicalTokenBlock) -> None:
         if block.ref_count == 0:
-            raise ValueError('Double free! '
-                             f'The block {block} is already freed.')
+            raise ValueError(f"Double free! {block} is already freed.")
         block.ref_count -= 1
         if block.ref_count == 0:
             self.free_blocks.append(block)
@@ -51,6 +54,7 @@ BlockTable = List[PhysicalTokenBlock]


 class BlockSpaceManager:
+    """Manages the mapping between logical and physical token blocks."""

     def __init__(
         self,
@@ -66,9 +70,10 @@ class BlockSpaceManager:
         assert watermark >= 0.0

         self.watermark_blocks = int(watermark * num_gpu_blocks)
-        self.gpu_allocator = BlockAllocator(Device.GPU, block_size, num_gpu_blocks)
-        self.cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks)
+        self.gpu_allocator = BlockAllocator(Device.GPU, block_size,
+                                            num_gpu_blocks)
+        self.cpu_allocator = BlockAllocator(Device.CPU, block_size,
+                                            num_cpu_blocks)
         # Mapping: seq_id -> BlockTable.
         self.block_tables: Dict[int, BlockTable] = {}

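
The BlockAllocator docstring above describes a reference-counted free list. A short usage sketch of that behavior follows; the import path is a guess because the diff does not show file names, and the sharing step only mimics how two sequences might reuse one physical block:

# Hypothetical import path; adjust to wherever BlockAllocator actually lives.
from cacheflow.core.block_manager import BlockAllocator
from cacheflow.utils import Device

allocator = BlockAllocator(Device.CPU, 16, 4)  # device, block_size, num_blocks

block = allocator.allocate()   # ref_count becomes 1
block.ref_count += 1           # a second owner starts sharing the block

allocator.free(block)          # first owner releases it; block stays allocated
allocator.free(block)          # last owner releases it; block rejoins the free list
assert block in allocator.free_blocks

try:
    allocator.free(block)      # freeing again trips the double-free check
except ValueError as exc:
    print(exc)                 # "Double free! ... is already freed."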

@@ -1,12 +1,12 @@
 import argparse
-from typing import List, Tuple, Optional
 import random
+from typing import List, Optional, Tuple

-import torch
 try:
     import ray
 except ImportError:
     ray = None
+import torch

 from cacheflow.core.scheduler import Scheduler
 from cacheflow.frontend.simple_frontend import SimpleFrontend

@@ -1,3 +1,4 @@
+"""Custom activation functions."""
 import torch
 import torch.nn as nn

@@ -5,6 +6,10 @@ from cacheflow import activation_ops


 class SiluAndMul(nn.Module):
+    """An activation function for SwiGLU.
+
+    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
+    """

     def __init__(self):
         super().__init__()
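
The SiluAndMul docstring gives the exact formula. The commit's layer calls a custom activation_ops kernel, but the same math can be written in plain PyTorch as a reference sketch:

import torch
import torch.nn.functional as F


def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    """Computes silu(x[:, :d]) * x[:, d:] with d = x.shape[1] // 2."""
    d = x.shape[1] // 2
    return F.silu(x[:, :d]) * x[:, d:]


x = torch.randn(4, 2 * 128)              # num_tokens x (2 * intermediate size)
print(silu_and_mul_reference(x).shape)   # torch.Size([4, 128])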

@@ -1,3 +1,4 @@
+"""Multi-head attention."""
 from typing import Optional

 import torch
@@ -11,6 +12,32 @@ from cacheflow.model_executor.input_metadata import InputMetadata


 class GPTCacheFlowAttention(nn.Module):
+    """GPT-style multi-head attention.
+
+    This class takes flattened 1D query, key, and value tensors as input. The
+    input 1D tensors can be split into three parts: the prompt tokens, the
+    generation tokens, and the paddings.
+
+    |<------------------------------------- num_valid_tokens ------------------------------------->|
+    |<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
+    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
+
+    The prompts might have different lengths, while the generation tokens always
+    have length 1. The paddings are appended to make the input length a multiple
+    of 8, which is desirable for Tensor Cores.
+
+    The class does the following:
+    1. Perform multi_query_kv_attention for the prompts. This operation does
+       not use the KV cache.
+    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
+       operations are issued by the cache engine before executing the forward
+       pass of the model, and they are executed asynchronously.
+    3. Reshape and store the input key and value tensors in the KV cache.
+    4. Perform single_query_cached_kv_attention for the generation tokens.
+       This operation reads the previous key and value tensors from the KV
+       cache.
+    5. Output a flattened 1D tensor.
+    """

     def __init__(self, scale: float) -> None:
         super().__init__()
@@ -157,7 +184,7 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
         torch_dtype = torch.get_default_dtype()
         cache = cache.to(torch_dtype)
         # Embedding size: [max_position, rotary_dim]
-        self.register_buffer('cos_sin_cache', cache, persistent=False)
+        self.register_buffer("cos_sin_cache", cache, persistent=False)

     def forward(
         self,
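
To make the flattened layout in the GPTCacheFlowAttention docstring concrete, here is a small standalone helper (hypothetical, not part of the commit) that computes where the prompt, generation, and padding segments fall along the 1D token dimension:

from typing import List, Tuple


def split_1d_layout(prompt_lens: List[int],
                    num_generation_tokens: int,
                    pad_multiple: int = 8) -> Tuple[slice, slice, slice]:
    """Returns (prompt, generation, padding) slices over the flattened tokens."""
    num_prompt_tokens = sum(prompt_lens)          # prompts may have different lengths
    num_valid_tokens = num_prompt_tokens + num_generation_tokens
    # Pad the total length up to a multiple of 8, which suits Tensor Cores.
    padded_len = -(-num_valid_tokens // pad_multiple) * pad_multiple
    return (slice(0, num_prompt_tokens),
            slice(num_prompt_tokens, num_valid_tokens),
            slice(num_valid_tokens, padded_len))


print(split_1d_layout(prompt_lens=[5, 3], num_generation_tokens=2))
# (slice(0, 8, None), slice(8, 10, None), slice(10, 16, None))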

@@ -1,3 +1,4 @@
+"""Custom normalization layers."""
 import torch
 import torch.nn as nn

@@ -5,6 +6,11 @@ from cacheflow import layernorm_ops


 class RMSNorm(nn.Module):
+    """Root mean square normalization.
+
+    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
+    Refer to https://arxiv.org/abs/1910.07467
+    """

     def __init__(
         self,
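
The RMSNorm docstring states the formula w * x / sqrt(E[x^2] + eps). The commit's layer dispatches to a fused layernorm_ops kernel; the equivalent math in plain PyTorch, as a reference sketch (the eps default below is a placeholder, not taken from this diff), is:

import torch
import torch.nn as nn


class RMSNormReference(nn.Module):
    """Computes w * x / sqrt(mean(x^2) + eps) over the last dimension."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(dim=-1, keepdim=True)   # E[x^2]
        return self.weight * x * torch.rsqrt(variance + self.eps)


norm = RMSNormReference(hidden_size=8)
print(norm(torch.randn(2, 8)).shape)   # torch.Size([2, 8])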

@@ -1,3 +1,4 @@
+"""A layer that samples the next tokens from the model's outputs."""
 from typing import Dict, List, Tuple

 import numpy as np
@@ -12,6 +13,19 @@ from cacheflow.sequence import SequenceOutputs


 class Sampler(nn.Module):
+    """Samples the next tokens from the model's outputs.
+
+    This layer does the following:
+    1. Discard the hidden states that are not used for sampling (i.e., all
+       tokens except the final one in each prompt).
+    2. Compute the logits for the next tokens.
+    3. Apply presence and frequency penalties.
+    4. Apply temperature scaling.
+    5. Apply top-p and top-k truncation.
+    6. Sample the next tokens.
+    Here, each sequence group within the batch can have different sampling
+    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
+    """

     def __init__(self, vocab_size: int) -> None:
         super().__init__()
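
As a rough illustration of steps 4-6 in the Sampler docstring (temperature scaling, top-k/top-p truncation, then sampling), here is a standalone sketch for a single logits vector. It is not the commit's Sampler, which also applies penalties and handles per-sequence-group parameters:

import torch


def sample_next_token(logits: torch.Tensor,
                      temperature: float = 1.0,
                      top_k: int = -1,
                      top_p: float = 1.0) -> int:
    if temperature == 0.0:
        return int(torch.argmax(logits))                   # zero temperature: greedy
    probs = torch.softmax(logits / temperature, dim=-1)    # 4. temperature scaling
    if top_k > 0:                                          # 5a. keep the k largest tokens
        kth_largest = torch.topk(probs, top_k).values[-1]
        probs[probs < kth_largest] = 0.0
    if top_p < 1.0:                                        # 5b. nucleus (top-p) truncation
        sorted_probs, sorted_idx = probs.sort(descending=True)
        cumulative = sorted_probs.cumsum(dim=-1)
        sorted_probs[cumulative - sorted_probs > top_p] = 0.0
        probs = torch.zeros_like(probs).scatter_(0, sorted_idx, sorted_probs)
    probs = probs / probs.sum()                            # renormalize the kept tokens
    return int(torch.multinomial(probs, num_samples=1))    # 6. sample


print(sample_next_token(torch.randn(32000), temperature=0.8, top_k=50, top_p=0.95))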

@@ -1,9 +1,9 @@
+"""Utilities for selecting and loading models."""
 from typing import Optional

 import torch
 import torch.nn as nn
-from transformers import AutoConfig
-from transformers import PretrainedConfig
+from transformers import AutoConfig, PretrainedConfig

 from cacheflow.model_executor.memory_analyzer import (
     CacheFlowMemoryAnalyzer, GPT2MemoryAnalyzer, GPTNeoXMemoryAnalyzer,

@@ -15,7 +15,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""1D GPT-2 model compatible with HuggingFace weights."""
+"""Inference-only GPT-2 model compatible with HuggingFace weights.
+
+The input of the model is flattened to a 1D tensor of tokens. The model uses
+InputMetadata to extract the original 2D shape of the input.
+"""
 from typing import Dict, List, Optional, Tuple

 import torch
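
This and the following model hunks make the same point: the input is a flattened 1D tensor of tokens, and InputMetadata carries what is needed to recover the original 2D structure. A tiny sketch of that flattening, with plain prompt lengths standing in for the metadata (whose fields this diff does not show):

import torch

# Two prompts of different lengths, conceptually a ragged "2D" batch.
prompts = [[101, 7592, 2088], [101, 2129, 2024, 2017, 102]]

input_ids = torch.tensor([tok for seq in prompts for tok in seq])          # flattened 1D
positions = torch.tensor([i for seq in prompts for i in range(len(seq))])  # per-token positions
prompt_lens = [len(seq) for seq in prompts]   # enough to recover the 2D shape

print(input_ids.shape, positions.tolist(), prompt_lens)
# torch.Size([8]) [0, 1, 2, 0, 1, 2, 3, 4] [3, 5]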

@@ -14,7 +14,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""1D GPT-NeoX model compatible with HuggingFace weights."""
+"""Inference-only GPT-NeoX model compatible with HuggingFace weights.
+
+The input of the model is flattened to a 1D tensor of tokens. The model uses
+InputMetadata to extract the original 2D shape of the input.
+"""
 from typing import Dict, List, Optional, Tuple

 import torch
@@ -79,6 +83,7 @@ class GPTNeoXAttention(nn.Module):


 class GPTNeoXMLP(nn.Module):
+
     def __init__(self, config: GPTNeoXConfig):
         super().__init__()
         self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,

@@ -19,7 +19,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""1D LLaMA model compatible with HuggingFace weights."""
+"""Inference-only LLaMA model compatible with HuggingFace weights.
+
+The input of the model is flattened to a 1D tensor of tokens. The model uses
+InputMetadata to extract the original 2D shape of the input.
+"""
 from typing import Dict, List, Optional, Tuple

 import torch

@@ -14,7 +14,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""1D OPT model compatible with HuggingFace weights."""
+"""Inference-only OPT model compatible with HuggingFace weights.
+
+The input of the model is flattened to a 1D tensor of tokens. The model uses
+InputMetadata to extract the original 2D shape of the input.
+"""
 from typing import Dict, List, Optional, Tuple

 import torch

@@ -1,3 +1,4 @@
+"""Utils for model executor."""
 import random
 from typing import Union

@@ -9,11 +10,11 @@ from cacheflow.model_executor.parallel_utils.tensor_parallel import model_parall


 _STR_DTYPE_TO_TORCH_DTYPE = {
-    'half': torch.half,
-    'float': torch.float,
-    'float16': torch.float16,
-    'float32': torch.float32,
-    'bfloat16': torch.bfloat16,
+    "half": torch.half,
+    "float": torch.float,
+    "float16": torch.float16,
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
 }


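
The quote-style change above touches the string-to-dtype table. A hypothetical helper showing how such a table is typically used (the module's real lookup function is not visible in this hunk):

from typing import Union

import torch

_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "float": torch.float,
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


def get_torch_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    # Pass through real torch dtypes, otherwise look the string up in the table.
    if isinstance(dtype, torch.dtype):
        return dtype
    return _STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]


print(get_torch_dtype("bfloat16"))   # torch.bfloat16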

@@ -1,3 +1,4 @@
+"""Utilities for downloading and initializing model weights."""
 import filelock
 import glob
 import json
@@ -106,5 +107,12 @@ def initialize_dummy_weights(
     low: float = -1e-3,
     high: float = 1e-3,
 ) -> None:
+    """Initialize model weights with random values.
+
+    The model weights must be randomly initialized for accurate performance
+    measurements. Additionally, the model weights should not cause NaNs in the
+    forward pass. We empirically found that initializing the weights with
+    values between -1e-3 and 1e-3 works well for most models.
+    """
     for param in model.state_dict().values():
         param.data.uniform_(low, high)
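
A minimal demonstration of the dummy-weight initialization that the new docstring explains, applied to a throwaway module rather than through the helper itself (whose full signature is not visible in this hunk):

import torch.nn as nn

model = nn.Linear(4, 4)

# Same effect as the documented loop in initialize_dummy_weights: every
# parameter is filled uniformly with values in [-1e-3, 1e-3].
for param in model.state_dict().values():
    param.data.uniform_(-1e-3, 1e-3)

print(bool(model.weight.abs().max() <= 1e-3))   # True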

@@ -1,7 +1,37 @@
+"""Sampling parameters for text generation."""
 from typing import Set


 class SamplingParams:
+    """Sampling parameters for text generation.
+
+    Overall, we follow the sampling parameters from the OpenAI text completion
+    API (https://platform.openai.com/docs/api-reference/completions/create).
+    In addition, we support beam search, which is not supported by OpenAI.
+
+    Args:
+        n: Number of output sequences to generate from the given prompt. This is
+            regarded as the beam width when using beam search.
+        presence_penalty: Float that penalizes new tokens based on whether they
+            appear in the generated text so far. Values > 0 encourage the model
+            to use new tokens, while values < 0 encourage the model to repeat
+            tokens.
+        frequency_penalty: Float that penalizes new tokens based on their
+            frequency in the generated text so far. Values > 0 encourage the
+            model to use new tokens, while values < 0 encourage the model to
+            repeat tokens.
+        temperature: Float that controls the randomness of the sampling. Lower
+            values make the model more deterministic, while higher values make
+            the model more random. Zero means greedy sampling.
+        top_p: Float that controls the cumulative probability of the top tokens
+            to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
+        top_k: Integer that controls the number of top tokens to consider. Set
+            to -1 to consider all tokens.
+        use_beam_search: Whether to use beam search instead of sampling.
+        stop_token_ids: Set of token IDs that indicate the end of a sequence.
+        max_tokens: Maximum number of tokens to generate per output sequence.
+        logprobs: Number of log probabilities to return per output token.
+    """

     def __init__(
         self,
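
A usage example assembled from the Args list above. The keyword names follow the documented arguments, but the exact constructor signature, defaults, and import path are not visible in this hunk, so treat this as a sketch:

from cacheflow.sampling_params import SamplingParams  # hypothetical import path

params = SamplingParams(
    n=2,                      # two output sequences per prompt
    presence_penalty=0.0,
    frequency_penalty=0.5,    # discourage tokens that already occur often
    temperature=0.8,          # 0 would mean greedy sampling
    top_p=0.95,               # must be in (0, 1]
    top_k=-1,                 # -1 considers all tokens
    use_beam_search=False,
    stop_token_ids={2},       # e.g. an EOS token id
    max_tokens=128,
    logprobs=0,
)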

@@ -1,12 +1,20 @@
+"""CacheEngine class for managing the KV cache."""
 from typing import Dict, List, Tuple

 import torch

 from cacheflow import cache_ops

 KVCache = Tuple[torch.Tensor, torch.Tensor]


 class CacheEngine:
+    """Manages the KV cache.
+
+    This class is responsible for initializing and managing the GPU and CPU KV
+    caches. It also provides methods for performing KV cache operations, such
+    as swapping and copying.
+    """

     def __init__(
         self,
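
The hunk above defines KVCache as a (key, value) tensor pair per layer. A rough sketch of allocating such a cache on the CPU side; the tensor shapes are assumptions for illustration only, since the real CacheEngine derives them from the model configuration and block size:

from typing import List, Tuple

import torch

KVCache = Tuple[torch.Tensor, torch.Tensor]


def allocate_cpu_cache(num_layers: int, num_blocks: int, block_size: int,
                       num_heads: int, head_size: int) -> List[KVCache]:
    cache: List[KVCache] = []
    for _ in range(num_layers):
        # One key tensor and one value tensor per layer; the block dimension is
        # what swap and copy operations move between the GPU and CPU caches.
        key = torch.empty(num_blocks, num_heads, head_size, block_size)
        value = torch.empty(num_blocks, num_heads, head_size, block_size)
        cache.append((key, value))
    return cache


cpu_cache = allocate_cpu_cache(num_layers=2, num_blocks=8, block_size=16,
                               num_heads=4, head_size=32)
print(len(cpu_cache), cpu_cache[0][0].shape)   # 2 torch.Size([8, 4, 32, 16])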

@@ -1,3 +1,4 @@
+"""A GPU worker class."""
 from typing import Dict, List, Optional, Tuple

 import torch
@@ -14,6 +15,12 @@ from cacheflow.worker.cache_engine import CacheEngine


 class Worker:
+    """A worker class that executes (a partition of) the model on a GPU.
+
+    Each worker is associated with a single GPU. The worker is responsible for
+    maintaining the KV cache and executing the model on the GPU. In case of
+    distributed inference, each worker is assigned a partition of the model.
+    """

     def __init__(
         self,