mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 10:26:07 +08:00
Refactor sliding window configuration to Transformers best practice (#21927)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
2a84fb422f
commit
c49848396d
@ -117,7 +117,7 @@ For models with interleaving sliding windows (e.g. `google/gemma-2-2b-it` and `m
|
||||
|
||||
To support a model with interleaving sliding windows, we need to take care of the following details:
|
||||
|
||||
- Make sure the model's `config.json` contains `sliding_window_pattern`. vLLM then sets `self.hf_text_config.interleaved_sliding_window` to the value of `self.hf_text_config.sliding_window` and deletes `sliding_window` from `self.hf_text_config`. The model will then be treated as a full-attention model.
|
||||
- Make sure the model's `config.json` contains `layer_types`.
|
||||
- In the modeling code, parse the correct sliding window value for every layer, and pass it to the attention layer's `per_layer_sliding_window` argument. For reference, check [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/model_executor/models/llama.py#L171).
|
||||
|
||||
With these two steps, interleave sliding windows should work with the model.
|
||||
|
||||
@ -200,28 +200,6 @@ def test_disable_sliding_window(model_id_expected):
|
||||
assert model_config.max_model_len == expected
|
||||
|
||||
|
||||
def test_get_sliding_window():
|
||||
TEST_SLIDING_WINDOW = 4096
|
||||
# Test that the sliding window is correctly computed.
|
||||
# For Qwen1.5/Qwen2, get_sliding_window() should be None
|
||||
# when use_sliding_window is False.
|
||||
qwen2_model_config = ModelConfig("Qwen/Qwen1.5-7B")
|
||||
|
||||
qwen2_model_config.hf_config.use_sliding_window = False
|
||||
qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
|
||||
assert qwen2_model_config.get_sliding_window() is None
|
||||
|
||||
qwen2_model_config.hf_config.use_sliding_window = True
|
||||
assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
|
||||
|
||||
mistral_model_config = ModelConfig("mistralai/Mistral-7B-v0.1")
|
||||
mistral_model_config.hf_config.sliding_window = None
|
||||
assert mistral_model_config.get_sliding_window() is None
|
||||
|
||||
mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
|
||||
assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
|
||||
|
||||
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||
reason="Xformers backend is not supported on ROCm.")
|
||||
def test_get_pooling_config():
|
||||
|
||||
@ -40,8 +40,9 @@ from vllm.transformers_utils.config import (
|
||||
ConfigFormat, get_config, get_hf_image_processor_config,
|
||||
get_hf_text_config, get_pooling_config,
|
||||
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
|
||||
maybe_override_with_speculators_target_model, try_get_generation_config,
|
||||
try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope)
|
||||
is_interleaved, maybe_override_with_speculators_target_model,
|
||||
try_get_generation_config, try_get_safetensors_metadata,
|
||||
try_get_tokenizer_config, uses_mrope)
|
||||
from vllm.transformers_utils.s3_utils import S3Model
|
||||
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
|
||||
# yapf conflicts with isort for this block
|
||||
@ -714,53 +715,31 @@ class ModelConfig:
|
||||
revision=self.revision,
|
||||
)
|
||||
|
||||
# Workaround for Gemma 2 which uses interleaved sliding window
|
||||
# attention, but it's not specified in its config.
|
||||
# TODO: remove this when Gemma 2 config updated in HuggingFace.
|
||||
if self.hf_text_config.model_type == "gemma2":
|
||||
self.hf_text_config.sliding_window_pattern = 2
|
||||
|
||||
# TODO: remove this when Gemma 3n config updated in HuggingFace.
|
||||
if self.hf_text_config.model_type == "gemma3n_text":
|
||||
# 4 sliding window attention followed by 1 full attention
|
||||
self.hf_text_config.sliding_window_pattern = "LLLLG"
|
||||
|
||||
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
|
||||
sliding_window_pattern = getattr(self.hf_text_config,
|
||||
"sliding_window_pattern", None)
|
||||
has_interleaved_attention = sliding_window_pattern is not None or (
|
||||
isinstance(sliding_window, list))
|
||||
|
||||
if not self.disable_sliding_window and has_interleaved_attention:
|
||||
if not envs.VLLM_USE_V1 and (backend := envs.VLLM_ATTENTION_BACKEND
|
||||
) in ("XFORMERS", "FLASHINFER"):
|
||||
sliding_window_len_min = get_min_sliding_window(
|
||||
self.hf_text_config.sliding_window)
|
||||
|
||||
logger.warning_once(
|
||||
"%s has interleaved attention, which is currently not supported by the %s backend. Disabling sliding window and capping the max length to the sliding window size (%d).", # noqa: E501
|
||||
self.hf_text_config.model_type,
|
||||
backend,
|
||||
sliding_window_len_min,
|
||||
)
|
||||
self.disable_sliding_window = True
|
||||
else:
|
||||
# for a model with interleaved attention,
|
||||
# the scheduler and the model treat it as full attention
|
||||
# (i.e., not dropping any tokens outside the window).
|
||||
# only the attention layer itself is aware of the sliding
|
||||
# window, and use the window size to compute the attention.
|
||||
self.hf_text_config.interleaved_sliding_window = sliding_window
|
||||
|
||||
if hasattr(self.hf_text_config, "sliding_window"):
|
||||
delattr(self.hf_text_config, "sliding_window")
|
||||
|
||||
sliding_window = None
|
||||
# Interleaved attention is not supported by some backends in V0
|
||||
if (not self.disable_sliding_window
|
||||
and is_interleaved(self.hf_text_config)
|
||||
and not envs.VLLM_USE_V1
|
||||
and (backend := envs.VLLM_ATTENTION_BACKEND)
|
||||
in ("XFORMERS", "FLASHINFER")):
|
||||
logger.warning_once(
|
||||
"%s has interleaved attention, which is currently not "
|
||||
"supported by the %s backend. Disabling sliding window and "
|
||||
"capping the max length to the sliding window size (%d).",
|
||||
self.hf_text_config.model_type,
|
||||
backend,
|
||||
self.hf_text_config.sliding_window,
|
||||
)
|
||||
self.disable_sliding_window = True
|
||||
|
||||
self.original_max_model_len = self.max_model_len
|
||||
self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
|
||||
self.multimodal_config = self._init_multimodal_config()
|
||||
|
||||
if self.disable_sliding_window:
|
||||
# Set after get_and_verify_max_len to ensure that max_model_len
|
||||
# can be correctly capped to sliding window size
|
||||
self.hf_text_config.sliding_window = None
|
||||
|
||||
if not self.skip_tokenizer_init:
|
||||
self._verify_tokenizer_mode()
|
||||
|
||||
@ -1322,27 +1301,10 @@ class ModelConfig:
|
||||
if self.use_async_output_proc:
|
||||
self.use_async_output_proc = False
|
||||
|
||||
def get_hf_config_sliding_window(
|
||||
self) -> Union[Optional[int], list[Optional[int]]]:
|
||||
"""Get the sliding window size, or None if disabled."""
|
||||
|
||||
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
|
||||
# addition to sliding window size. We check if that field is present
|
||||
# and if it's False, return None.
|
||||
if (hasattr(self.hf_text_config, "use_sliding_window")
|
||||
and not self.hf_text_config.use_sliding_window):
|
||||
return None
|
||||
def get_sliding_window(self) -> Optional[int]:
|
||||
"""Get the sliding window size from the HF text config if present."""
|
||||
return getattr(self.hf_text_config, "sliding_window", None)
|
||||
|
||||
def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
|
||||
"""Get the sliding window size, or None if disabled.
|
||||
"""
|
||||
# If user disables sliding window, return None.
|
||||
if self.disable_sliding_window:
|
||||
return None
|
||||
# Otherwise get the value from the hf config.
|
||||
return self.get_hf_config_sliding_window()
|
||||
|
||||
def get_vocab_size(self) -> int:
|
||||
return getattr(self.hf_text_config, "vocab_size", 0)
|
||||
|
||||
@ -1762,7 +1724,7 @@ class ModelConfig:
|
||||
tokenizer_config=tokenizer_config,
|
||||
max_model_len=max_model_len,
|
||||
disable_sliding_window=self.disable_sliding_window,
|
||||
sliding_window_len=self.get_hf_config_sliding_window(),
|
||||
sliding_window=self.get_sliding_window(),
|
||||
spec_target_max_model_len=self.spec_target_max_model_len,
|
||||
encoder_config=self.encoder_config)
|
||||
logger.info("Using max model len %s", max_model_len)
|
||||
@ -3305,7 +3267,7 @@ def _get_and_verify_max_len(
|
||||
tokenizer_config: Optional[dict],
|
||||
max_model_len: Optional[int],
|
||||
disable_sliding_window: bool,
|
||||
sliding_window_len: Optional[Union[int, list[Optional[int]]]],
|
||||
sliding_window: Optional[int],
|
||||
spec_target_max_model_len: Optional[int] = None,
|
||||
encoder_config: Optional[Any] = None,
|
||||
) -> int:
|
||||
@ -3344,13 +3306,10 @@ def _get_and_verify_max_len(
|
||||
|
||||
# If sliding window is manually disabled, max_length should be less
|
||||
# than the sliding window length in the model config.
|
||||
if disable_sliding_window and sliding_window_len is not None:
|
||||
|
||||
sliding_window_len_min = get_min_sliding_window(sliding_window_len)
|
||||
max_len_key = "sliding_window" \
|
||||
if sliding_window_len_min < derived_max_model_len else max_len_key
|
||||
derived_max_model_len = min(derived_max_model_len,
|
||||
sliding_window_len_min)
|
||||
if (disable_sliding_window and sliding_window is not None
|
||||
and sliding_window < derived_max_model_len):
|
||||
max_len_key = "sliding_window"
|
||||
derived_max_model_len = sliding_window
|
||||
|
||||
# Consider model_max_length in tokenizer_config
|
||||
if tokenizer_config:
|
||||
@ -3451,14 +3410,6 @@ def _get_and_verify_max_len(
|
||||
return int(max_model_len)
|
||||
|
||||
|
||||
def get_min_sliding_window(
|
||||
sliding_window: Union[int, list[Optional[int]]]) -> int:
|
||||
if isinstance(sliding_window, list):
|
||||
return min(s for s in sliding_window if s is not None)
|
||||
|
||||
return sliding_window
|
||||
|
||||
|
||||
def get_served_model_name(model: str,
|
||||
served_model_name: Optional[Union[str, list[str]]]):
|
||||
"""
|
||||
|
||||
@ -39,6 +39,7 @@ from vllm.plugins import load_general_plugins
|
||||
from vllm.ray.lazy_utils import is_ray_initialized
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
|
||||
from vllm.transformers_utils.config import is_interleaved
|
||||
from vllm.transformers_utils.utils import check_gguf_file
|
||||
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
|
||||
GiB_bytes, get_ip, is_in_ray_actor)
|
||||
@ -1081,6 +1082,13 @@ class EngineArgs:
|
||||
"DualChunkFlashAttention is not supported on V1 engine. "
|
||||
"To run the model in V0 engine, try set 'VLLM_USE_V1=0'")
|
||||
|
||||
sliding_window: Optional[int] = None
|
||||
if not is_interleaved(model_config.hf_text_config):
|
||||
# Only set CacheConfig.sliding_window if the model is all sliding
|
||||
# window. Otherwise CacheConfig.sliding_window will override the
|
||||
# global layers in interleaved sliding window models.
|
||||
sliding_window = model_config.get_sliding_window()
|
||||
|
||||
cache_config = CacheConfig(
|
||||
block_size=self.block_size,
|
||||
gpu_memory_utilization=self.gpu_memory_utilization,
|
||||
@ -1088,7 +1096,7 @@ class EngineArgs:
|
||||
cache_dtype=self.kv_cache_dtype,
|
||||
is_attention_free=model_config.is_attention_free,
|
||||
num_gpu_blocks_override=self.num_gpu_blocks_override,
|
||||
sliding_window=model_config.get_sliding_window(),
|
||||
sliding_window=sliding_window,
|
||||
enable_prefix_caching=self.enable_prefix_caching,
|
||||
prefix_caching_hash_algo=self.prefix_caching_hash_algo,
|
||||
cpu_offload_gb=self.cpu_offload_gb,
|
||||
|
||||
@ -182,21 +182,13 @@ class CohereAttention(nn.Module):
|
||||
)
|
||||
|
||||
# Model v2 has interleaved sliding windows, v1 does not
|
||||
interleaved_sliding_window = getattr(config,
|
||||
"interleaved_sliding_window",
|
||||
None)
|
||||
self.v1 = interleaved_sliding_window is None
|
||||
self.v1 = isinstance(config, CohereConfig)
|
||||
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
layer_has_sliding_window = (
|
||||
getattr(config, "sliding_window_pattern", False) and
|
||||
(layer_idx + 1) % self.config.sliding_window_pattern
|
||||
!= 0) or (getattr(config, "layer_types", False)
|
||||
and config.layer_types[layer_idx] == "sliding_attention")
|
||||
|
||||
self.sliding_window = (interleaved_sliding_window
|
||||
or config.sliding_window
|
||||
if layer_has_sliding_window else None)
|
||||
self.sliding_window = None
|
||||
if not self.v1:
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
if config.layer_types[layer_idx] == "sliding_attention":
|
||||
self.sliding_window = config.sliding_window
|
||||
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
|
||||
@ -159,25 +159,12 @@ class Exaone4Attention(nn.Module):
|
||||
if quant_config is not None and quant_config.get_name() == "gguf":
|
||||
is_neox_style = False
|
||||
|
||||
self.apply_all_layers = False # apply rotary embeddings to every layer.
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
interleaved_sliding_window = getattr(config,
|
||||
"interleaved_sliding_window",
|
||||
4096)
|
||||
sliding_window_pattern = getattr(config, "sliding_window_pattern",
|
||||
"LLLG")
|
||||
is_sliding = config.layer_types[layer_idx] == "sliding_attention"
|
||||
self.sliding_window = config.sliding_window if is_sliding else None
|
||||
|
||||
if sliding_window_pattern:
|
||||
layer_has_sliding_window = (
|
||||
layer_idx + 1) % sliding_window_pattern.__len__() != 0
|
||||
else:
|
||||
layer_has_sliding_window = False
|
||||
self.apply_all_layers = True
|
||||
|
||||
if layer_has_sliding_window:
|
||||
self.sliding_window = interleaved_sliding_window
|
||||
else:
|
||||
self.sliding_window = None
|
||||
# apply rotary embeddings to every layer
|
||||
self.apply_all_layers = not is_sliding
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
|
||||
@ -144,13 +144,10 @@ class Gemma2Attention(nn.Module):
|
||||
is_neox_style=True,
|
||||
)
|
||||
|
||||
# reference:
|
||||
# https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
use_sliding_window = (layer_idx % 2 == 0 and getattr(
|
||||
config, "interleaved_sliding_window", None) is not None)
|
||||
sliding_window = config.interleaved_sliding_window if \
|
||||
use_sliding_window else None
|
||||
is_sliding = config.layer_types[layer_idx] == "sliding_attention"
|
||||
sliding_window = config.sliding_window if is_sliding else None
|
||||
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
|
||||
@ -146,25 +146,19 @@ class Gemma3Attention(nn.Module):
|
||||
self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||
self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||
|
||||
# TODO(woosuk): Add reference to the original HF implementation.
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
self.is_sliding = (getattr(
|
||||
config, "interleaved_sliding_window", None) is not None and (bool(
|
||||
(layer_idx + 1) % config.sliding_window_pattern))) or (
|
||||
getattr(config, "layer_types", None) is not None
|
||||
and config.layer_types[layer_idx] == "sliding_attention")
|
||||
self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
|
||||
sliding_window = config.sliding_window if self.is_sliding else None
|
||||
|
||||
# Initialize the rotary embedding.
|
||||
if self.is_sliding:
|
||||
# Local attention. Override the values in config.json.
|
||||
self.rope_theta = config.rope_local_base_freq
|
||||
self.rope_scaling = {"rope_type": "default"}
|
||||
self.sliding_window = (config.interleaved_sliding_window
|
||||
or config.sliding_window)
|
||||
else:
|
||||
# Global attention. Use the values in config.json.
|
||||
self.rope_theta = config.rope_theta
|
||||
self.rope_scaling = config.rope_scaling
|
||||
self.sliding_window = None
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
@ -182,7 +176,7 @@ class Gemma3Attention(nn.Module):
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
logits_soft_cap=attn_logits_soft_cap,
|
||||
per_layer_sliding_window=self.sliding_window,
|
||||
per_layer_sliding_window=sliding_window,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
|
||||
@ -502,8 +502,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.sliding_window = getattr(config.text_config,
|
||||
"interleaved_sliding_window", None)
|
||||
|
||||
self.vision_tower = SiglipVisionModel(config.vision_config,
|
||||
quant_config,
|
||||
@ -690,11 +688,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
|
||||
global_attn_masks.append(global_attn_mask)
|
||||
|
||||
if self.sliding_window is not None:
|
||||
if (sliding_window := self.config.sliding_window) is not None:
|
||||
# Create a local causal mask with sliding window (1024).
|
||||
local_attn_mask = torch.ones_like(global_attn_mask)
|
||||
local_attn_mask = torch.tril(local_attn_mask,
|
||||
diagonal=-self.sliding_window)
|
||||
diagonal=-sliding_window)
|
||||
local_attn_mask = torch.where(local_attn_mask == 0,
|
||||
global_attn_mask, float("-inf"))
|
||||
local_attn_masks.append(local_attn_mask)
|
||||
|
||||
@ -313,17 +313,16 @@ class Gemma3nAttention(nn.Module):
|
||||
has_weight=False)
|
||||
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
is_sliding = config.layer_types[layer_idx] == "sliding_attention"
|
||||
self.sliding_window = config.sliding_window if is_sliding else None
|
||||
|
||||
is_sliding_window = (
|
||||
getattr(config, "interleaved_sliding_window", None) is not None
|
||||
and config.layer_types[layer_idx] == "sliding_attention")
|
||||
|
||||
if is_sliding_window:
|
||||
self.sliding_window = config.interleaved_sliding_window
|
||||
# Initialize the rotary embedding.
|
||||
if is_sliding:
|
||||
# Local attention. Override the values in config.json.
|
||||
rope_theta = config.rope_local_base_freq
|
||||
rope_scaling = {"rope_type": "default"}
|
||||
else:
|
||||
self.sliding_window = None
|
||||
# Global attention. Use the values in config.json.
|
||||
rope_theta = config.rope_theta
|
||||
rope_scaling = config.rope_scaling
|
||||
|
||||
|
||||
@ -248,9 +248,7 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
|
||||
|
||||
vllm_config.cache_config.sliding_window = None
|
||||
|
||||
for attr in ("sliding_window", "interleaved_sliding_window"):
|
||||
if hasattr(hf_config, attr):
|
||||
delattr(hf_config, attr)
|
||||
hf_config.sliding_window = None
|
||||
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
|
||||
|
||||
@ -167,18 +167,11 @@ class LlamaAttention(nn.Module):
|
||||
rope_scaling=rope_scaling,
|
||||
quant_config=quant_config)
|
||||
|
||||
if hasattr(config, "interleaved_sliding_window"):
|
||||
interleaved_sliding_window = config.interleaved_sliding_window
|
||||
if isinstance(interleaved_sliding_window, int):
|
||||
sliding_window = interleaved_sliding_window
|
||||
elif isinstance(interleaved_sliding_window, list):
|
||||
sw_idx = layer_idx % len(interleaved_sliding_window)
|
||||
sliding_window = interleaved_sliding_window[sw_idx]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{type(interleaved_sliding_window)} is not supported.")
|
||||
else:
|
||||
sliding_window = None
|
||||
sliding_window = None
|
||||
if layer_types := getattr(config, "layer_types", None):
|
||||
is_sliding = layer_types[layer_idx] == "sliding_attention"
|
||||
if is_sliding:
|
||||
sliding_window = config.sliding_window
|
||||
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
|
||||
@ -116,13 +116,8 @@ class SambaYAttention(nn.Module):
|
||||
self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True)
|
||||
|
||||
# disable sliding window for the second half of the model
|
||||
sliding_window = config.interleaved_sliding_window[layer_idx]
|
||||
if layer_idx >= config.num_hidden_layers // 2:
|
||||
assert sliding_window is None, \
|
||||
"sliding_window must be none for the second decoder"
|
||||
else:
|
||||
assert sliding_window is not None, \
|
||||
"sliding_window must be set for the first decoder"
|
||||
is_sliding = config.layer_types[layer_idx] == "sliding_attention"
|
||||
sliding_window = config.sliding_window if is_sliding else None
|
||||
|
||||
assert self.num_heads % 2 == 0, 'num_heads should be even'
|
||||
assert self.num_key_value_heads % 2 == 0, 'num_heads should be even'
|
||||
|
||||
@ -49,6 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.config import is_interleaved
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
|
||||
@ -285,8 +286,7 @@ class Qwen2Model(nn.Module):
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
# TODO (@robertgshaw2): see if this can be moved out
|
||||
if (cache_config.sliding_window is not None
|
||||
and hasattr(config, "max_window_layers")):
|
||||
if is_interleaved(vllm_config.model_config.hf_text_config):
|
||||
assert config.max_window_layers == config.num_hidden_layers, (
|
||||
"Sliding window for some but all layers is not supported. "
|
||||
"This model uses sliding window but `max_window_layers` = {} "
|
||||
|
||||
@ -16,7 +16,7 @@
|
||||
# limitations under the License.
|
||||
"""Wrapper around `transformers` models"""
|
||||
from collections.abc import Iterable, Mapping
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from contextlib import contextmanager
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import regex as re
|
||||
@ -382,33 +382,6 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
|
||||
)
|
||||
|
||||
|
||||
class ConfigOverride:
|
||||
"""Context manager to temporarily override config attributes."""
|
||||
|
||||
def __init__(self, config: PretrainedConfig, **kwargs):
|
||||
self.config = config
|
||||
self.kwargs = kwargs
|
||||
self.kwargs_original = {}
|
||||
self.kwargs_delete = set()
|
||||
|
||||
def __enter__(self):
|
||||
"""Override config attributes."""
|
||||
for key, value in self.kwargs.items():
|
||||
if not hasattr(self.config, key):
|
||||
self.kwargs_delete.add(key)
|
||||
self.kwargs_original[key] = getattr(self.config, key, None)
|
||||
setattr(self.config, key, value)
|
||||
return self.config
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
"""Restore original config attributes."""
|
||||
for key, value in self.kwargs_original.items():
|
||||
if key in self.kwargs_delete:
|
||||
delattr(self.config, key)
|
||||
else:
|
||||
setattr(self.config, key, value)
|
||||
|
||||
|
||||
class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
embedding_modules = ["embed_tokens"
|
||||
@ -434,21 +407,11 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
# To be updated in child classes for use in `load_weights`
|
||||
self.skip_prefixes: Optional[list[str]] = None
|
||||
|
||||
# vLLM handles interleaved sliding window attention by creating a new
|
||||
# interleaved_sliding_window attribute and deleting the sliding_window
|
||||
# attribute. This breaks the constructors in Transformers so we
|
||||
# temporarily add the attribute back to construct the model.
|
||||
config_override = nullcontext()
|
||||
if hasattr(self.config, "interleaved_sliding_window"):
|
||||
config_override = ConfigOverride(
|
||||
self.config,
|
||||
sliding_window=self.config.interleaved_sliding_window)
|
||||
|
||||
# Set correct attn and init on "meta" to delay allocating GPU tensors
|
||||
# TODO: @raushan, use the public `model.set_attn_implementation()`
|
||||
# method once its checks are fixed in Transformers.
|
||||
self.text_config._attn_implementation = "vllm"
|
||||
with init_on_device_without_buffers("meta"), config_override:
|
||||
with init_on_device_without_buffers("meta"):
|
||||
self.model: PreTrainedModel = AutoModel.from_config(
|
||||
self.config,
|
||||
torch_dtype=self.model_config.dtype,
|
||||
@ -575,11 +538,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
attention_instances = {}
|
||||
for i in range(start, end):
|
||||
# Handle interleaved sliding window attention
|
||||
sliding_window = None
|
||||
if (hasattr(self.config, "interleaved_sliding_window")
|
||||
and hasattr(self.config, "sliding_window_pattern")
|
||||
and ((i + 1) % self.config.sliding_window_pattern > 0)):
|
||||
sliding_window = self.config.interleaved_sliding_window
|
||||
per_layer_sliding_window = None
|
||||
if (hasattr(self.config, "layer_types")
|
||||
and self.config.layer_types[i] == "sliding_attention"):
|
||||
per_layer_sliding_window = self.config.sliding_window
|
||||
|
||||
attention_instances[i] = Attention(
|
||||
num_heads=num_heads,
|
||||
@ -590,7 +552,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
num_kv_heads=num_kv_heads,
|
||||
cache_config=self.cache_config,
|
||||
quant_config=self.quant_config,
|
||||
per_layer_sliding_window=sliding_window,
|
||||
per_layer_sliding_window=per_layer_sliding_window,
|
||||
prefix=f"{i}.attn")
|
||||
return attention_instances
|
||||
|
||||
|
||||
@ -280,6 +280,17 @@ def is_encoder_decoder(config: PretrainedConfig) -> bool:
|
||||
return getattr(config, "is_encoder_decoder", False)
|
||||
|
||||
|
||||
def is_interleaved(config: PretrainedConfig) -> bool:
|
||||
"""
|
||||
Detect if the model with this config is used with interleaved attention.
|
||||
"""
|
||||
text_config = config.get_text_config()
|
||||
if layer_types := getattr(text_config, "layer_types", None):
|
||||
interleaved_types = {"full_attention", "sliding_attention"}
|
||||
return interleaved_types.issubset(layer_types)
|
||||
return False
|
||||
|
||||
|
||||
def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
|
||||
"""Remap config attributes to match the expected names."""
|
||||
for old_attr, new_attr in _CONFIG_ATTRS_MAPPING.items():
|
||||
@ -423,6 +434,23 @@ def get_config(
|
||||
raise e
|
||||
config = _maybe_remap_hf_config_attrs(config)
|
||||
|
||||
# Phi4Flash misuses this config as list[int]. Convert it to int and add
|
||||
# the layer_types list[str] to make it HF compatible
|
||||
if (config.model_type == "phi4flash"):
|
||||
# TODO: Remove after the following PR is merged:
|
||||
# https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/discussions/6
|
||||
if not hasattr(config, "layer_types"):
|
||||
config.layer_types = [
|
||||
"sliding_attention" if i < config.num_hidden_layers // 2
|
||||
and i % 2 == 1 else "full_attention"
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
# TODO: Remove after the following PR is merged:
|
||||
# https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/discussions/7
|
||||
if isinstance(config.sliding_window, list):
|
||||
config.sliding_window = next(
|
||||
filter(None, config.sliding_window), None)
|
||||
|
||||
elif config_format == ConfigFormat.MISTRAL:
|
||||
# This function loads a params.json config which
|
||||
# should be used when loading models in mistral format
|
||||
@ -434,6 +462,18 @@ def get_config(
|
||||
config_dict["max_position_embeddings"] = max_position_embeddings
|
||||
|
||||
config = adapt_config_dict(config_dict)
|
||||
|
||||
# Mistral configs may define sliding_window as list[int]. Convert it
|
||||
# to int and add the layer_types list[str] to make it HF compatible
|
||||
if ((sliding_window := getattr(config, "sliding_window", None))
|
||||
and isinstance(sliding_window, list)):
|
||||
pattern_repeats = config.num_hidden_layers // len(sliding_window)
|
||||
layer_types = sliding_window * pattern_repeats
|
||||
config.layer_types = [
|
||||
"full_attention" if layer_type is None else "sliding_attention"
|
||||
for layer_type in layer_types
|
||||
]
|
||||
config.sliding_window = next(filter(None, sliding_window), None)
|
||||
else:
|
||||
supported_formats = [
|
||||
fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user