mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-30 12:51:49 +08:00
[Interleaved ATTN] Support for Mistral-8B (#10591)
Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
16ee07f22a
commit
e7cfc4ef4c
@ -54,7 +54,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
|
||||
is_pp_missing_parameter,
|
||||
extract_layer_index, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
@ -114,6 +114,7 @@ class LlamaAttention(nn.Module):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
@ -168,6 +169,18 @@ class LlamaAttention(nn.Module):
|
||||
rope_scaling=rope_scaling,
|
||||
is_neox_style=is_neox_style,
|
||||
)
|
||||
|
||||
if hasattr(config, "interleaved_sliding_window"):
|
||||
if isinstance(config.interleaved_sliding_window, int):
|
||||
sliding_window = config.interleaved_sliding_window
|
||||
elif isinstance(config.interleaved_sliding_window, list):
|
||||
sw_idx = layer_idx % len(config.interleaved_sliding_window)
|
||||
sliding_window = config.interleaved_sliding_window[sw_idx]
|
||||
else:
|
||||
raise ValueError(f"{type(sliding_window)} is not supported.")
|
||||
else:
|
||||
sliding_window = None
|
||||
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
@ -175,6 +188,7 @@ class LlamaAttention(nn.Module):
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
per_layer_sliding_window=sliding_window,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user