From e7cfc4ef4cc017e0a0229adff9f4b143b38fb421 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Sat, 30 Nov 2024 08:45:50 +0100
Subject: [PATCH] [Interleaved ATTN] Support for Mistral-8B (#10591)

Signed-off-by: youkaichao
Co-authored-by: youkaichao
---
 vllm/model_executor/models/llama.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index fe94bb352961b..ff0ab011a9158 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -54,7 +54,7 @@
 from vllm.sequence import IntermediateTensors, PoolerOutput
 
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
-                    is_pp_missing_parameter,
+                    extract_layer_index, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -114,6 +114,7 @@ class LlamaAttention(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        layer_idx = extract_layer_index(prefix)
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
@@ -168,6 +169,18 @@ class LlamaAttention(nn.Module):
             rope_scaling=rope_scaling,
             is_neox_style=is_neox_style,
         )
+
+        if hasattr(config, "interleaved_sliding_window"):
+            if isinstance(config.interleaved_sliding_window, int):
+                sliding_window = config.interleaved_sliding_window
+            elif isinstance(config.interleaved_sliding_window, list):
+                sw_idx = layer_idx % len(config.interleaved_sliding_window)
+                sliding_window = config.interleaved_sliding_window[sw_idx]
+            else:
+                raise ValueError(f"{type(config.interleaved_sliding_window)} is not supported.")
+        else:
+            sliding_window = None
+
         self.attn = Attention(
             self.num_heads,
             self.head_dim,
@@ -175,6 +188,7 @@ class LlamaAttention(nn.Module):
             num_kv_heads=self.num_kv_heads,
             cache_config=cache_config,
             quant_config=quant_config,
+            per_layer_sliding_window=sliding_window,
             prefix=f"{prefix}.attn",
         )
 
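
For illustration, here is a minimal standalone sketch of the per-layer
window selection that the third hunk implements. `DummyConfig`,
`resolve_sliding_window`, and the window values are hypothetical names
for this sketch, not vLLM code: an `int` config applies one window to
every layer, a `list` is cycled through by layer index, and a missing
attribute means full attention everywhere.

    from dataclasses import dataclass
    from typing import List, Optional, Union

    @dataclass
    class DummyConfig:
        # Stands in for the HF config object whose
        # `interleaved_sliding_window` attribute the patch reads.
        interleaved_sliding_window: Union[int, List[Optional[int]]]

    def resolve_sliding_window(config, layer_idx: int) -> Optional[int]:
        """Same branching as the hunk above."""
        if not hasattr(config, "interleaved_sliding_window"):
            return None  # no attribute: full attention on every layer
        window = config.interleaved_sliding_window
        if isinstance(window, int):
            return window  # one window size shared by all layers
        if isinstance(window, list):
            return window[layer_idx % len(window)]  # cycle by layer index
        raise ValueError(f"{type(window)} is not supported.")

    # Hypothetical alternating pattern: even layers get a 4096-token
    # sliding window, odd layers get full attention (None).
    config = DummyConfig(interleaved_sliding_window=[4096, None])
    print([resolve_sliding_window(config, i) for i in range(4)])
    # [4096, None, 4096, None]

In the patch itself, the resolved value is then passed to `Attention`
as `per_layer_sliding_window` (fourth hunk), so each decoder layer
builds its attention backend with its own window.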