diff --git a/requirements.txt b/requirements.txt index d398c7220eca8..e7491957fb8f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ sentencepiece # Required for LLaMA tokenizer. numpy torch >= 2.0.0 transformers >= 4.33.1 # Required for Code Llama. -xformers >= 0.0.21 +xformers >= 0.0.22 fastapi uvicorn[standard] pydantic < 2 # Required for OpenAI server. diff --git a/vllm/config.py b/vllm/config.py index 5fc7696b83111..90d574bb3b2aa 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -187,10 +187,12 @@ class CacheConfig: block_size: int, gpu_memory_utilization: float, swap_space: int, + sliding_window: Optional[int] = None, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization self.swap_space_bytes = swap_space * _GB + self.sliding_window = sliding_window self._verify_args() # Will be set after profiling. diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index a8262c4722c5e..57349e7fe7f92 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -63,10 +63,18 @@ class BlockSpaceManager: num_gpu_blocks: int, num_cpu_blocks: int, watermark: float = 0.01, + sliding_window: Optional[int] = None, ) -> None: self.block_size = block_size self.num_total_gpu_blocks = num_gpu_blocks self.num_total_cpu_blocks = num_cpu_blocks + + self.block_sliding_window = None + if sliding_window is not None: + assert sliding_window % block_size == 0, (sliding_window, + block_size) + self.block_sliding_window = sliding_window // block_size + self.watermark = watermark assert watermark >= 0.0 @@ -83,6 +91,9 @@ class BlockSpaceManager: # the same prompt. This may not be true for preempted sequences. seq = seq_group.get_seqs()[0] num_required_blocks = len(seq.logical_token_blocks) + if self.block_sliding_window is not None: + num_required_blocks = min(num_required_blocks, + self.block_sliding_window) num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() # Use watermark to avoid frequent cache eviction. return (num_free_gpu_blocks - num_required_blocks >= @@ -95,8 +106,12 @@ class BlockSpaceManager: # Allocate new physical token blocks that will store the prompt tokens. block_table: BlockTable = [] - for _ in range(len(seq.logical_token_blocks)): - block = self.gpu_allocator.allocate() + for logical_idx in range(len(seq.logical_token_blocks)): + if (self.block_sliding_window is not None + and logical_idx >= self.block_sliding_window): + block = block_table[logical_idx % self.block_sliding_window] + else: + block = self.gpu_allocator.allocate() # Set the reference counts of the token blocks. block.ref_count = seq_group.num_seqs() block_table.append(block) @@ -118,11 +133,17 @@ class BlockSpaceManager: block_table = self.block_tables[seq.seq_id] if len(block_table) < len(logical_blocks): - # The sequence has a new logical block. - # Allocate a new physical block. - block = self.gpu_allocator.allocate() - block_table.append(block) - return None + if (self.block_sliding_window + and len(block_table) >= self.block_sliding_window): + # re-use a block + block_table.append(block_table[len(block_table) % + self.block_sliding_window]) + else: + # The sequence has a new logical block. + # Allocate a new physical block. + block = self.gpu_allocator.allocate() + block_table.append(block) + return None # We want to append the token to the last physical block. last_block = block_table[-1] @@ -154,9 +175,7 @@ class BlockSpaceManager: for seq in seq_group.get_seqs(): if seq.is_finished(): continue - block_table = self.block_tables[seq.seq_id] - for block in block_table: - blocks.add(block) + blocks.update(self.block_tables[seq.seq_id]) return list(blocks) def can_swap_in(self, seq_group: SequenceGroup) -> bool: @@ -224,7 +243,7 @@ class BlockSpaceManager: return block_number_mapping def _free_block_table(self, block_table: BlockTable) -> None: - for block in block_table: + for block in set(block_table): if block.device == Device.GPU: self.gpu_allocator.free(block) else: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index b5e0da48d1d19..18d6b3ed2ea75 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -73,7 +73,7 @@ class Scheduler: block_size=self.cache_config.block_size, num_gpu_blocks=self.cache_config.num_gpu_blocks, num_cpu_blocks=self.cache_config.num_cpu_blocks, - ) + sliding_window=self.cache_config.sliding_window) # TODO(zhuohan): Use deque instead of list for better performance. # Sequence groups in the WAITING state. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 237782df96f84..8951a98e4159a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -176,9 +176,9 @@ class EngineArgs: self.download_dir, self.load_format, self.dtype, self.seed, self.revision, self.max_model_len, self.quantization) - cache_config = CacheConfig(self.block_size, - self.gpu_memory_utilization, - self.swap_space) + cache_config = CacheConfig( + self.block_size, self.gpu_memory_utilization, self.swap_space, + getattr(model_config.hf_config, 'sliding_window', None)) parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 9922bda7d14c6..3e1026bfaefce 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -86,6 +86,8 @@ class LLMEngine: self.model_config = model_config self.cache_config = cache_config + assert self.cache_config.sliding_window == getattr( + self.model_config.hf_config, "sliding_window", None) self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.log_stats = log_stats diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index 1b0bc7327f7a9..a0a62034aa24e 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import torch from xformers.ops import AttentionBias @@ -29,6 +29,7 @@ class InputMetadata: context_lens: torch.Tensor, max_context_len: int, block_tables: torch.Tensor, + sliding_window: Optional[int] = None, ) -> None: self.seq_groups = seq_groups self.seq_data = seq_data @@ -38,6 +39,24 @@ class InputMetadata: self.max_context_len = max_context_len self.block_tables = block_tables + self.to_cache = None + if sliding_window is not None: + # We need to keep the positions of sliding windows within + # the key / value tables, this is helpful to know which + # elements we need to cache and where + to_cache, start_idx = [], 0 + for prompt_len in self.prompt_lens: + to_cache.extend( + range( + start_idx + max(0, prompt_len - sliding_window), + start_idx + prompt_len, + )) + start_idx += prompt_len + to_cache.extend(range(start_idx, slot_mapping.shape[0])) + self.to_cache = torch.tensor(to_cache, + dtype=torch.int32, + device=self.slot_mapping.device) + self.num_prompts = len(prompt_lens) self.num_prompt_tokens = sum(prompt_lens) self.num_generation_tokens = context_lens.shape[0] diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index ccae52120b692..b1d0588d97f7e 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -58,12 +58,14 @@ class PagedAttention(nn.Module): num_heads: int, head_size: int, scale: float, - num_kv_heads: Optional[int] = None) -> None: + num_kv_heads: Optional[int] = None, + sliding_window: Optional[int] = None) -> None: super().__init__() self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.sliding_window = sliding_window assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -86,6 +88,8 @@ class PagedAttention(nn.Module): return prompt_lens = input_metadata.prompt_lens attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens) + if self.sliding_window is not None: + attn_bias = attn_bias.make_local_attention(self.sliding_window) input_metadata.attn_bias.append(attn_bias) def multi_query_kv_attention( @@ -223,12 +227,20 @@ class PagedAttention(nn.Module): if (num_valid_tokens > 0 and key_cache is not None and value_cache is not None): # The stride is 3 because the key and value are sliced from qkv. + key_to_cache = key[:num_valid_tokens] + value_to_cache = value[:num_valid_tokens] + slot_mapping = input_metadata.slot_mapping + if input_metadata.to_cache is not None: + key_to_cache = key_to_cache[input_metadata.to_cache] + value_to_cache = value_to_cache[input_metadata.to_cache] + slot_mapping = slot_mapping[input_metadata.to_cache] + cache_ops.reshape_and_cache( - key[:num_valid_tokens], - value[:num_valid_tokens], + key_to_cache, + value_to_cache, key_cache, value_cache, - input_metadata.slot_mapping, + slot_mapping, ) if input_metadata.num_generation_tokens > 0: @@ -262,8 +274,13 @@ class PagedAttentionWithRoPE(PagedAttention): num_kv_heads: Optional[int] = None, is_neox_style: bool = True, rope_scaling: Optional[Dict[str, Any]] = None, + sliding_window: Optional[int] = None, ) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads) + super().__init__(num_heads, + head_size, + scale, + num_kv_heads, + sliding_window=sliding_window) if rope_scaling is None: self.rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 526b4f8b5c877..951ba1f0ceba6 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -25,6 +25,7 @@ _MODEL_REGISTRY = { "InternLMForCausalLM": InternLMForCausalLM, "LlamaForCausalLM": LlamaForCausalLM, "LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-* + "MistralForCausalLM": MistralForCausalLM, "MPTForCausalLM": MPTForCausalLM, "OPTForCausalLM": OPTForCausalLM, "QWenLMHeadModel": QWenLMHeadModel, diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index f20e5d8e6f20e..01d85355b2979 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -12,6 +12,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.mpt import MPTForCausalLM from vllm.model_executor.models.opt import OPTForCausalLM from vllm.model_executor.models.qwen import QWenLMHeadModel +from vllm.model_executor.models.mistral import MistralForCausalLM __all__ = [ "AquilaForCausalLM", @@ -28,4 +29,5 @@ __all__ = [ "MPTForCausalLM", "OPTForCausalLM", "QWenLMHeadModel", + "MistralForCausalLM", ] diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py new file mode 100644 index 0000000000000..7bc8575756516 --- /dev/null +++ b/vllm/model_executor/models/mistral.py @@ -0,0 +1,404 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only LLaMA model compatible with HuggingFace weights. + +The input of the model is flattened to a 1D tensor of tokens. The model uses +InputMetadata to extract the original 2D shape of the input. +""" +from typing import List, Optional, Tuple + +import torch +from torch import nn +from vllm.transformers_utils.configs.mistral import MistralConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.quantized_linear import ParallelLinear +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.parallel_utils.tensor_parallel import ( + VocabParallelEmbedding) +from vllm.model_executor.quantization_utils import QuantizationConfig +from vllm.model_executor.weight_utils import ( + convert_pyslice_to_tensor, hf_model_weights_iterator, + load_tensor_parallel_weights, load_padded_tensor_parallel_vocab) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class MistralMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = ParallelLinear.column(hidden_size, + 2 * intermediate_size, + bias=False, + gather_output=False, + perform_initialization=False, + quant_config=quant_config) + self.down_proj = ParallelLinear.row(intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class MistralAttention(nn.Module): + + def __init__(self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + quant_config: Optional[QuantizationConfig] = None, + sliding_window: Optional[int] = None) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + assert self.total_num_kv_heads % tp_size == 0 + self.num_kv_heads = self.total_num_kv_heads // tp_size + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.sliding_window = sliding_window + + self.qkv_proj = ParallelLinear.column( + hidden_size, + (self.total_num_heads + 2 * self.total_num_kv_heads) * + self.head_dim, + bias=False, + gather_output=False, + perform_initialization=False, + quant_config=quant_config, + ) + self.o_proj = ParallelLinear.row( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + quant_config=quant_config, + ) + self.attn = PagedAttentionWithRoPE(self.num_heads, + self.head_dim, + self.scaling, + base=self.rope_theta, + max_position=max_position, + rotary_dim=self.head_dim, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + k_cache, v_cache = kv_cache + attn_output = self.attn(positions, q, k, v, k_cache, v_cache, + input_metadata, cache_event) + output, _ = self.o_proj(attn_output) + return output + + +class MistralDecoderLayer(nn.Module): + + def __init__( + self, + config: MistralConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = MistralAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + quant_config=quant_config, + sliding_window=config.sliding_window) + self.mlp = MistralMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + cache_event=cache_event, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class MistralModel(nn.Module): + + def __init__( + self, + config: MistralConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + vocab_size = ((config.vocab_size + 63) // 64) * 64 + self.embed_tokens = VocabParallelEmbedding( + vocab_size, config.hidden_size, perform_initialization=False) + self.layers = nn.ModuleList([ + MistralDecoderLayer(config, quant_config) + for _ in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + for i in range(len(self.layers)): + if cache_events is None: + cache_event = None + else: + cache_event = cache_events[i] + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + cache_event, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class MistralForCausalLM(nn.Module): + + def __init__( + self, + config: MistralConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = MistralModel(config, quant_config) + vocab_size = ((config.vocab_size + 63) // 64) * 64 + # NOTE: The LM head is not quantized. + self.lm_head = ParallelLinear.column(config.hidden_size, + vocab_size, + bias=False, + gather_output=False, + perform_initialization=False, + quant_config=None) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> SamplerOutput: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata, cache_events) + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + input_metadata) + return next_tokens + + _column_parallel_layers = [] + _row_parallel_layers = ["o_proj", "down_proj"] + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + if self.quant_config is None: + weight_suffixes = ["weight"] + else: + weight_suffixes = self.quant_config.get_tp_tensor_names() + + column_parallel_weights: List[str] = [] + for layer in self._column_parallel_layers: + for suffix in weight_suffixes: + column_parallel_weights.append(f"{layer}.{suffix}") + row_parallel_weights: List[str] = [] + for layer in self._row_parallel_layers: + for suffix in weight_suffixes: + row_parallel_weights.append(f"{layer}.{suffix}") + + tp_size = get_tensor_model_parallel_world_size() + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + q_proj_shard_size = (self.config.hidden_size // tp_size) + kv_proj_shard_size = (self.config.hidden_size // + self.config.num_attention_heads * + self.config.num_key_value_heads // tp_size) + attention_weight_specs = [ + # (weight_name, shard_size, offset) + ("q_proj", q_proj_shard_size, 0), + ("k_proj", kv_proj_shard_size, q_proj_shard_size), + ("v_proj", kv_proj_shard_size, + q_proj_shard_size + kv_proj_shard_size), + ] + state_dict = self.state_dict() + + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + + is_packed = False + is_transposed = False + if self.quant_config is not None: + is_packed = self.quant_config.is_packed(name) + is_transposed = self.quant_config.is_transposed(name) + if is_transposed: + loaded_weight = convert_pyslice_to_tensor(loaded_weight) + loaded_weight = loaded_weight.T + + is_attention_weight = False + for weight_name, shard_size, offset in attention_weight_specs: + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "qkv_proj")] + if is_transposed: + param = param.T + + if is_packed: + shard_size //= self.quant_config.pack_factor + offset //= self.quant_config.pack_factor + + loaded_weight = loaded_weight[ + shard_size * tensor_model_parallel_rank:shard_size * + (tensor_model_parallel_rank + 1)] + param_slice = param.data[offset:offset + shard_size] + assert param_slice.shape == loaded_weight.shape + + param_slice.copy_(loaded_weight) + is_attention_weight = True + break + if is_attention_weight: + continue + + is_gate_up_weight = False + for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "gate_up_proj")] + if is_transposed: + param = param.T + + shard_size = param.shape[0] // 2 + loaded_weight = loaded_weight[ + shard_size * tensor_model_parallel_rank:shard_size * + (tensor_model_parallel_rank + 1)] + param_slice = param.data[shard_size * stride_id:shard_size * + (stride_id + 1)] + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_gate_up_weight = True + break + if is_gate_up_weight: + continue + + param = state_dict[name] + if is_transposed: + param = param.T + + if "embed_tokens" in name or "lm_head" in name: + load_padded_tensor_parallel_vocab(param, loaded_weight, + tensor_model_parallel_rank) + continue + + load_tensor_parallel_weights(param, loaded_weight, name, + column_parallel_weights, + row_parallel_weights, + tensor_model_parallel_rank) diff --git a/vllm/transformers_utils/configs/mistral.py b/vllm/transformers_utils/configs/mistral.py new file mode 100644 index 0000000000000..0a7d9a8efa349 --- /dev/null +++ b/vllm/transformers_utils/configs/mistral.py @@ -0,0 +1,66 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Mistral-7B-v0.1 configuration""" +from transformers.configuration_utils import PretrainedConfig + + +class MistralConfig(PretrainedConfig): + model_type = "mistral" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=10000.0, + sliding_window=4096, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 3239a819794ed..e14086f271073 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -137,8 +137,10 @@ class Worker: self.cache_config = cache_config self.block_size = cache_config.block_size - _check_if_can_support_max_seq_len(self.scheduler_config.max_model_len, - self.block_size) + max_seq_len = min(self.scheduler_config.max_model_len, + cache_config.sliding_window or float("inf")) + + _check_if_can_support_max_seq_len(max_seq_len, self.block_size) self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.parallel_config) @@ -191,6 +193,9 @@ class Worker: slot = block_number * self.block_size + block_offset slot_mapping.append(slot) + sliding_window = getattr(self.model_config.hf_config, "sliding_window", + float("inf")) + # Add generation tokens. max_context_len = 0 max_num_blocks_per_seq = 0 @@ -211,10 +216,11 @@ class Worker: context_len = seq_data.get_len() position = context_len - 1 + if sliding_window: + context_len = min(context_len, sliding_window) input_positions.append(position) block_table = seq_group_metadata.block_tables[seq_id] - generation_block_tables.append(block_table) max_context_len = max(max_context_len, context_len) max_num_blocks_per_seq = max(max_num_blocks_per_seq, @@ -226,6 +232,13 @@ class Worker: slot = block_number * self.block_size + block_offset slot_mapping.append(slot) + if sliding_window: + assert self.cache_config is not None + sliding_window_blocks = (sliding_window // + self.cache_config.block_size) + block_table = block_table[-sliding_window_blocks:] + generation_block_tables.append(block_table) + # Optimization: Pad the input length to be a multiple of 8. # This is required for utilizing the Tensor Cores in NVIDIA GPUs. input_tokens = _pad_to_alignment(input_tokens, multiple_of=8) @@ -264,6 +277,7 @@ class Worker: context_lens=context_lens_tensor, max_context_len=max_context_len, block_tables=block_tables_tensor, + sliding_window=sliding_window, ) return tokens_tensor, positions_tensor, input_metadata