mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-05 08:04:32 +08:00
[Interleaved ATTN] Support for Mistral-8B (#10591)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
parent 16ee07f22a
commit e7cfc4ef4c
@@ -54,7 +54,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
-                    is_pp_missing_parameter,
+                    extract_layer_index, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -114,6 +114,7 @@ class LlamaAttention(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        layer_idx = extract_layer_index(prefix)
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
@@ -168,6 +169,18 @@ class LlamaAttention(nn.Module):
             rope_scaling=rope_scaling,
             is_neox_style=is_neox_style,
         )
+
+        if hasattr(config, "interleaved_sliding_window"):
+            if isinstance(config.interleaved_sliding_window, int):
+                sliding_window = config.interleaved_sliding_window
+            elif isinstance(config.interleaved_sliding_window, list):
+                sw_idx = layer_idx % len(config.interleaved_sliding_window)
+                sliding_window = config.interleaved_sliding_window[sw_idx]
+            else:
+                raise ValueError(f"{type(config.interleaved_sliding_window)} is not supported.")
+        else:
+            sliding_window = None
+
         self.attn = Attention(
             self.num_heads,
             self.head_dim,
@@ -175,6 +188,7 @@ class LlamaAttention(nn.Module):
             num_kv_heads=self.num_kv_heads,
             cache_config=cache_config,
             quant_config=quant_config,
+            per_layer_sliding_window=sliding_window,
             prefix=f"{prefix}.attn",
         )
 
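For readers tracing the hunks above: the added block cycles config.interleaved_sliding_window across decoder layers, so an int applies one window everywhere while a list pattern lets layers alternate between sliding-window and full attention. The following is a minimal, self-contained sketch of that selection logic; resolve_sliding_window is a hypothetical helper written for this illustration, extract_layer_index approximates the .utils helper the diff imports, and the [4096, None] pattern is an assumed example, not Mistral-8B's actual config.

import re
from typing import List, Optional, Union

def extract_layer_index(prefix: str) -> int:
    # Approximation of the vllm .utils helper the diff imports: pull the
    # decoder-layer index out of a prefix like "model.layers.7.self_attn".
    match = re.search(r"layers\.(\d+)", prefix)
    if match is None:
        raise ValueError(f"no layer index in prefix: {prefix!r}")
    return int(match.group(1))

def resolve_sliding_window(
    interleaved_sliding_window: Union[int, List[Optional[int]], None],
    layer_idx: int,
) -> Optional[int]:
    # Mirrors the added branch: an int applies the same window to every
    # layer; a list is cycled so layer i uses entry i % len(pattern).
    if interleaved_sliding_window is None:
        return None
    if isinstance(interleaved_sliding_window, int):
        return interleaved_sliding_window
    if isinstance(interleaved_sliding_window, list):
        sw_idx = layer_idx % len(interleaved_sliding_window)
        return interleaved_sliding_window[sw_idx]
    raise ValueError(f"{type(interleaved_sliding_window)} is not supported.")

# Example: a two-entry pattern over six layers alternates a 4096-token
# sliding window with full attention (None).
pattern = [4096, None]
print([resolve_sliding_window(pattern,
                              extract_layer_index(f"model.layers.{i}.self_attn"))
       for i in range(6)])
# -> [4096, None, 4096, None, 4096, None]

The resolved per-layer value is then threaded into the Attention constructor via per_layer_sliding_window, which is what the final hunk adds.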