[Mistral] Mistral-7B-v0.1 support (#1196)

Co-authored-by: timlacroix <t@mistral.ai>
Chris Bamford 2023-09-28 19:41:03 +02:00 committed by GitHub
parent 7bedab5748
commit bb1ba58f06
13 changed files with 571 additions and 25 deletions

View File

@@ -7,7 +7,7 @@ sentencepiece # Required for LLaMA tokenizer.
numpy
torch >= 2.0.0
transformers >= 4.33.1 # Required for Code Llama.
-xformers >= 0.0.21
+xformers >= 0.0.22
fastapi
uvicorn[standard]
pydantic < 2 # Required for OpenAI server.

View File

@@ -187,10 +187,12 @@ class CacheConfig:
block_size: int,
gpu_memory_utilization: float,
swap_space: int,
sliding_window: Optional[int] = None,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * _GB
self.sliding_window = sliding_window
self._verify_args()
# Will be set after profiling.

View File

@@ -63,10 +63,18 @@ class BlockSpaceManager:
num_gpu_blocks: int,
num_cpu_blocks: int,
watermark: float = 0.01,
sliding_window: Optional[int] = None,
) -> None:
self.block_size = block_size
self.num_total_gpu_blocks = num_gpu_blocks
self.num_total_cpu_blocks = num_cpu_blocks
self.block_sliding_window = None
if sliding_window is not None:
assert sliding_window % block_size == 0, (sliding_window,
block_size)
self.block_sliding_window = sliding_window // block_size
self.watermark = watermark
assert watermark >= 0.0
@@ -83,6 +91,9 @@ class BlockSpaceManager:
# the same prompt. This may not be true for preempted sequences.
seq = seq_group.get_seqs()[0]
num_required_blocks = len(seq.logical_token_blocks)
if self.block_sliding_window is not None:
num_required_blocks = min(num_required_blocks,
self.block_sliding_window)
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
# Use watermark to avoid frequent cache eviction.
return (num_free_gpu_blocks - num_required_blocks >=
@@ -95,8 +106,12 @@ class BlockSpaceManager:
# Allocate new physical token blocks that will store the prompt tokens.
block_table: BlockTable = []
-for _ in range(len(seq.logical_token_blocks)):
-block = self.gpu_allocator.allocate()
+for logical_idx in range(len(seq.logical_token_blocks)):
+if (self.block_sliding_window is not None
+and logical_idx >= self.block_sliding_window):
+block = block_table[logical_idx % self.block_sliding_window]
+else:
+block = self.gpu_allocator.allocate()
# Set the reference counts of the token blocks.
block.ref_count = seq_group.num_seqs()
block_table.append(block)
@@ -118,11 +133,17 @@ class BlockSpaceManager:
block_table = self.block_tables[seq.seq_id]
if len(block_table) < len(logical_blocks):
-# The sequence has a new logical block.
-# Allocate a new physical block.
-block = self.gpu_allocator.allocate()
-block_table.append(block)
-return None
+if (self.block_sliding_window
+and len(block_table) >= self.block_sliding_window):
+# re-use a block
+block_table.append(block_table[len(block_table) %
+self.block_sliding_window])
+else:
+# The sequence has a new logical block.
+# Allocate a new physical block.
+block = self.gpu_allocator.allocate()
+block_table.append(block)
+return None
# We want to append the token to the last physical block.
last_block = block_table[-1]
@@ -154,9 +175,7 @@ class BlockSpaceManager:
for seq in seq_group.get_seqs():
if seq.is_finished():
continue
-block_table = self.block_tables[seq.seq_id]
-for block in block_table:
-blocks.add(block)
+blocks.update(self.block_tables[seq.seq_id])
return list(blocks)
def can_swap_in(self, seq_group: SequenceGroup) -> bool:
@@ -224,7 +243,7 @@ class BlockSpaceManager:
return block_number_mapping
def _free_block_table(self, block_table: BlockTable) -> None:
-for block in block_table:
+for block in set(block_table):
if block.device == Device.GPU:
self.gpu_allocator.free(block)
else:
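
The change above caps a sequence's physical blocks at `sliding_window // block_size` and wraps later logical blocks back onto the same physical blocks. A minimal sketch of that mapping (not vLLM code; the window and block size below are assumptions, matching Mistral's 4096-token window and a hypothetical 16-token block):

```python
# Sketch of the circular block-table mapping implemented above.
# Assumed values: 4096-token sliding window, 16-token blocks.
sliding_window = 4096
block_size = 16
assert sliding_window % block_size == 0
block_sliding_window = sliding_window // block_size  # at most 256 physical blocks


def physical_block_index(logical_idx: int) -> int:
    # Logical blocks past the window re-use earlier physical blocks,
    # mirroring allocate() and append_slot() in BlockSpaceManager.
    if logical_idx >= block_sliding_window:
        return logical_idx % block_sliding_window
    return logical_idx


print([physical_block_index(i) for i in (0, 255, 256, 300)])  # [0, 255, 0, 44]
```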

View File

@@ -73,7 +73,7 @@ class Scheduler:
block_size=self.cache_config.block_size,
num_gpu_blocks=self.cache_config.num_gpu_blocks,
num_cpu_blocks=self.cache_config.num_cpu_blocks,
-)
+sliding_window=self.cache_config.sliding_window)
# TODO(zhuohan): Use deque instead of list for better performance.
# Sequence groups in the WAITING state.

View File

@@ -176,9 +176,9 @@ class EngineArgs:
self.download_dir, self.load_format,
self.dtype, self.seed, self.revision,
self.max_model_len, self.quantization)
-cache_config = CacheConfig(self.block_size,
-self.gpu_memory_utilization,
-self.swap_space)
+cache_config = CacheConfig(
+self.block_size, self.gpu_memory_utilization, self.swap_space,
+getattr(model_config.hf_config, 'sliding_window', None))
parallel_config = ParallelConfig(self.pipeline_parallel_size,
self.tensor_parallel_size,
self.worker_use_ray)

View File

@@ -86,6 +86,8 @@ class LLMEngine:
self.model_config = model_config
self.cache_config = cache_config
assert self.cache_config.sliding_window == getattr(
self.model_config.hf_config, "sliding_window", None)
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.log_stats = log_stats

View File

@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
import torch
from xformers.ops import AttentionBias
@@ -29,6 +29,7 @@ class InputMetadata:
context_lens: torch.Tensor,
max_context_len: int,
block_tables: torch.Tensor,
sliding_window: Optional[int] = None,
) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
@@ -38,6 +39,24 @@ class InputMetadata:
self.max_context_len = max_context_len
self.block_tables = block_tables
self.to_cache = None
if sliding_window is not None:
# We need to keep the positions of the sliding windows within
# the key / value tables; this tells us which elements we need
# to cache and where.
to_cache, start_idx = [], 0
for prompt_len in self.prompt_lens:
to_cache.extend(
range(
start_idx + max(0, prompt_len - sliding_window),
start_idx + prompt_len,
))
start_idx += prompt_len
to_cache.extend(range(start_idx, slot_mapping.shape[0]))
self.to_cache = torch.tensor(to_cache,
dtype=torch.int32,
device=self.slot_mapping.device)
self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens)
self.num_generation_tokens = context_lens.shape[0]
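
The `to_cache` logic above keeps, for each prompt, only the positions that fall inside the sliding window (all generation-token slots are kept), so only those keys/values are written to the cache. A worked example of the same index arithmetic with made-up prompt lengths and window size:

```python
import torch

# Assumed inputs: two prompts of lengths 6 and 3, a window of 4, and one
# generation token, so slot_mapping has 10 entries in total.
prompt_lens = [6, 3]
sliding_window = 4
slot_mapping = torch.arange(10)

to_cache, start_idx = [], 0
for prompt_len in prompt_lens:
    # Keep only the last `sliding_window` positions of this prompt.
    to_cache.extend(
        range(start_idx + max(0, prompt_len - sliding_window),
              start_idx + prompt_len))
    start_idx += prompt_len
# Generation tokens are always cached.
to_cache.extend(range(start_idx, slot_mapping.shape[0]))

print(to_cache)  # [2, 3, 4, 5, 6, 7, 8, 9]
```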

View File

@@ -58,12 +58,14 @@ class PagedAttention(nn.Module):
num_heads: int,
head_size: int,
scale: float,
-num_kv_heads: Optional[int] = None) -> None:
+num_kv_heads: Optional[int] = None,
+sliding_window: Optional[int] = None) -> None:
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -86,6 +88,8 @@ class PagedAttention(nn.Module):
return
prompt_lens = input_metadata.prompt_lens
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(self.sliding_window)
input_metadata.attn_bias.append(attn_bias)
def multi_query_kv_attention(
@@ -223,12 +227,20 @@ class PagedAttention(nn.Module):
if (num_valid_tokens > 0 and key_cache is not None
and value_cache is not None):
# The stride is 3 because the key and value are sliced from qkv.
+key_to_cache = key[:num_valid_tokens]
+value_to_cache = value[:num_valid_tokens]
+slot_mapping = input_metadata.slot_mapping
+if input_metadata.to_cache is not None:
+key_to_cache = key_to_cache[input_metadata.to_cache]
+value_to_cache = value_to_cache[input_metadata.to_cache]
+slot_mapping = slot_mapping[input_metadata.to_cache]
cache_ops.reshape_and_cache(
-key[:num_valid_tokens],
-value[:num_valid_tokens],
+key_to_cache,
+value_to_cache,
key_cache,
value_cache,
-input_metadata.slot_mapping,
+slot_mapping,
)
if input_metadata.num_generation_tokens > 0:
@@ -262,8 +274,13 @@ class PagedAttentionWithRoPE(PagedAttention):
num_kv_heads: Optional[int] = None,
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
sliding_window: Optional[int] = None,
) -> None:
-super().__init__(num_heads, head_size, scale, num_kv_heads)
+super().__init__(num_heads,
+head_size,
+scale,
+num_kv_heads,
+sliding_window=sliding_window)
if rope_scaling is None:
self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
max_position, base,
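
For prompt processing the window is enforced through the attention bias rather than the cache: `make_local_attention` turns the block-diagonal causal mask into a banded (local) causal mask. A small sketch of how that bias is built and consumed with xformers, using hypothetical shapes and requiring a CUDA device:

```python
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

# Assumed batch: two prompts of lengths 5 and 3 packed into one sequence,
# 4 heads, head size 64, fp16 on GPU.
prompt_lens = [5, 3]
sliding_window = 4
q = torch.randn(1, sum(prompt_lens), 4, 64, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# Same construction as multi_query_kv_attention above: a block-diagonal
# causal mask over the packed prompts, restricted to a local window.
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
attn_bias = attn_bias.make_local_attention(sliding_window)

out = xops.memory_efficient_attention_forward(q, k, v, attn_bias=attn_bias)
print(out.shape)  # torch.Size([1, 8, 4, 64])
```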

View File

@@ -25,6 +25,7 @@ _MODEL_REGISTRY = {
"InternLMForCausalLM": InternLMForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM,
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
"MistralForCausalLM": MistralForCausalLM,
"MPTForCausalLM": MPTForCausalLM,
"OPTForCausalLM": OPTForCausalLM,
"QWenLMHeadModel": QWenLMHeadModel,

View File

@@ -12,6 +12,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.mpt import MPTForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.models.qwen import QWenLMHeadModel
from vllm.model_executor.models.mistral import MistralForCausalLM
__all__ = [
"AquilaForCausalLM",
@@ -28,4 +29,5 @@ __all__ = [
"MPTForCausalLM",
"OPTForCausalLM",
"QWenLMHeadModel",
"MistralForCausalLM",
]
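
With the architecture registered, Mistral loads through the normal vLLM entry points; no model-specific API is needed. A minimal usage sketch (the HuggingFace model id is assumed and the weights must be downloadable):

```python
from vllm import LLM, SamplingParams

# Assumes access to the Mistral-7B-v0.1 weights on the HuggingFace Hub.
llm = LLM(model="mistralai/Mistral-7B-v0.1")
sampling_params = SamplingParams(temperature=0.8, max_tokens=64)
outputs = llm.generate(["The capital of France is"], sampling_params)
print(outputs[0].outputs[0].text)
```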

View File

@@ -0,0 +1,404 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import List, Optional, Tuple
import torch
from torch import nn
from vllm.transformers_utils.configs.mistral import MistralConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.quantized_linear import ParallelLinear
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding)
from vllm.model_executor.quantization_utils import QuantizationConfig
from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor, hf_model_weights_iterator,
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class MistralMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = ParallelLinear.column(hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=quant_config)
self.down_proj = ParallelLinear.row(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class MistralAttention(nn.Module):
def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
assert self.total_num_kv_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.qkv_proj = ParallelLinear.column(
hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=quant_config,
)
self.o_proj = ParallelLinear.row(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
quant_config=quant_config,
)
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.scaling,
base=self.rope_theta,
max_position=max_position,
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output
class MistralDecoderLayer(nn.Module):
def __init__(
self,
config: MistralConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = MistralAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
quant_config=quant_config,
sliding_window=config.sliding_window)
self.mlp = MistralMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class MistralModel(nn.Module):
def __init__(
self,
config: MistralConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.embed_tokens = VocabParallelEmbedding(
vocab_size, config.hidden_size, perform_initialization=False)
self.layers = nn.ModuleList([
MistralDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class MistralForCausalLM(nn.Module):
def __init__(
self,
config: MistralConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = MistralModel(config, quant_config)
vocab_size = ((config.vocab_size + 63) // 64) * 64
# NOTE: The LM head is not quantized.
self.lm_head = ParallelLinear.column(config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=None)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_layers = []
_row_parallel_layers = ["o_proj", "down_proj"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
if self.quant_config is None:
weight_suffixes = ["weight"]
else:
weight_suffixes = self.quant_config.get_tp_tensor_names()
column_parallel_weights: List[str] = []
for layer in self._column_parallel_layers:
for suffix in weight_suffixes:
column_parallel_weights.append(f"{layer}.{suffix}")
row_parallel_weights: List[str] = []
for layer in self._row_parallel_layers:
for suffix in weight_suffixes:
row_parallel_weights.append(f"{layer}.{suffix}")
tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size)
kv_proj_shard_size = (self.config.hidden_size //
self.config.num_attention_heads *
self.config.num_key_value_heads // tp_size)
attention_weight_specs = [
# (weight_name, shard_size, offset)
("q_proj", q_proj_shard_size, 0),
("k_proj", kv_proj_shard_size, q_proj_shard_size),
("v_proj", kv_proj_shard_size,
q_proj_shard_size + kv_proj_shard_size),
]
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue
is_packed = False
is_transposed = False
if self.quant_config is not None:
is_packed = self.quant_config.is_packed(name)
is_transposed = self.quant_config.is_transposed(name)
if is_transposed:
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
loaded_weight = loaded_weight.T
is_attention_weight = False
for weight_name, shard_size, offset in attention_weight_specs:
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "qkv_proj")]
if is_transposed:
param = param.T
if is_packed:
shard_size //= self.quant_config.pack_factor
offset //= self.quant_config.pack_factor
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[offset:offset + shard_size]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break
if is_attention_weight:
continue
is_gate_up_weight = False
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
if is_transposed:
param = param.T
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break
if is_gate_up_weight:
continue
param = state_dict[name]
if is_transposed:
param = param.T
if "embed_tokens" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight,
tensor_model_parallel_rank)
continue
load_tensor_parallel_weights(param, loaded_weight, name,
column_parallel_weights,
row_parallel_weights,
tensor_model_parallel_rank)
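
The checkpoint stores separate q_proj/k_proj/v_proj (and gate_proj/up_proj) tensors while the model uses fused qkv_proj/gate_up_proj layers, so each loaded tensor is sliced for the current tensor-parallel rank and copied to its offset in the fused weight. A worked example of the attention shard sizes, using the MistralConfig defaults added in this PR and an assumed tensor-parallel size of 2:

```python
# Shard-size arithmetic from load_weights above. Config values are the
# MistralConfig defaults in this PR; tp_size=2 is an assumption.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8
tp_size = 2

q_proj_shard_size = hidden_size // tp_size                 # 2048 rows per rank
kv_proj_shard_size = (hidden_size // num_attention_heads *
                      num_key_value_heads // tp_size)      # 512 rows per rank

# Fused qkv_proj rows per rank, and the offsets of q / k / v within it.
print(q_proj_shard_size + 2 * kv_proj_shard_size)          # 3072
print(0, q_proj_shard_size, q_proj_shard_size + kv_proj_shard_size)  # 0 2048 2560
```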

View File

@@ -0,0 +1,66 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mistral-7B-v0.1 configuration"""
from transformers.configuration_utils import PretrainedConfig
class MistralConfig(PretrainedConfig):
model_type = "mistral"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=10000.0,
sliding_window=4096,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
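
A quick sanity check of the defaults defined above (assuming vLLM is installed so the module path resolves):

```python
from vllm.transformers_utils.configs.mistral import MistralConfig

config = MistralConfig()
print(config.max_position_embeddings)  # 131072 (4096 * 32)
print(config.sliding_window)           # 4096
print(config.num_key_value_heads)      # 8 (grouped-query attention)
```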

View File

@@ -137,8 +137,10 @@ class Worker:
self.cache_config = cache_config
self.block_size = cache_config.block_size
-_check_if_can_support_max_seq_len(self.scheduler_config.max_model_len,
-self.block_size)
+max_seq_len = min(self.scheduler_config.max_model_len,
+cache_config.sliding_window or float("inf"))
+_check_if_can_support_max_seq_len(max_seq_len, self.block_size)
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
self.parallel_config)
@@ -191,6 +193,9 @@ class Worker:
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
sliding_window = getattr(self.model_config.hf_config, "sliding_window",
float("inf"))
# Add generation tokens.
max_context_len = 0
max_num_blocks_per_seq = 0
@@ -211,10 +216,11 @@ class Worker:
context_len = seq_data.get_len()
position = context_len - 1
if sliding_window:
context_len = min(context_len, sliding_window)
input_positions.append(position)
block_table = seq_group_metadata.block_tables[seq_id]
-generation_block_tables.append(block_table)
max_context_len = max(max_context_len, context_len)
max_num_blocks_per_seq = max(max_num_blocks_per_seq,
@@ -226,6 +232,13 @@ class Worker:
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
if sliding_window:
assert self.cache_config is not None
sliding_window_blocks = (sliding_window //
self.cache_config.block_size)
block_table = block_table[-sliding_window_blocks:]
generation_block_tables.append(block_table)
# Optimization: Pad the input length to be a multiple of 8.
# This is required for utilizing the Tensor Cores in NVIDIA GPUs.
input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
@@ -264,6 +277,7 @@ class Worker:
context_lens=context_lens_tensor,
max_context_len=max_context_len,
block_tables=block_tables_tensor,
sliding_window=sliding_window,
)
return tokens_tensor, positions_tensor, input_metadata
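
For generation, the worker both clamps the context length to the window and keeps only the last `sliding_window // block_size` entries of the block table, since older blocks have been overwritten in place by the circular allocation. A small illustration with assumed numbers:

```python
# Assumed numbers: 4096-token window, 16-token blocks, a sequence that has
# grown to 4500 tokens (one block-table entry per logical block).
sliding_window = 4096
block_size = 16
context_len = 4500
block_table = list(range(282))  # 282 = ceil(4500 / 16) logical blocks

# Clamp the attended context to the window, as in _prepare_inputs above.
context_len = min(context_len, sliding_window)         # 4096
# Keep only the blocks that can still hold live tokens.
sliding_window_blocks = sliding_window // block_size    # 256
block_table = block_table[-sliding_window_blocks:]
print(context_len, len(block_table))                    # 4096 256
```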