Refactor sliding window configuration to Transformers best practice (#21927)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-08-10 04:50:48 +01:00 committed by GitHub
parent 2a84fb422f
commit c49848396d
16 changed files with 123 additions and 231 deletions

View File

@@ -117,7 +117,7 @@ For models with interleaving sliding windows (e.g. `google/gemma-2-2b-it` and `m
 To support a model with interleaving sliding windows, we need to take care of the following details:
-- Make sure the model's `config.json` contains `sliding_window_pattern`. vLLM then sets `self.hf_text_config.interleaved_sliding_window` to the value of `self.hf_text_config.sliding_window` and deletes `sliding_window` from `self.hf_text_config`. The model will then be treated as a full-attention model.
+- Make sure the model's `config.json` contains `layer_types`.
 - In the modeling code, parse the correct sliding window value for every layer, and pass it to the attention layer's `per_layer_sliding_window` argument. For reference, check [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/model_executor/models/llama.py#L171).
 With these two steps, interleave sliding windows should work with the model.
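For illustration only, a minimal self-contained sketch of the `layer_types` convention the doc describes (the config class and its values below are invented, not part of this repository): each layer looks up its entry in `layer_types`, and only `"sliding_attention"` layers get the window, which the modeling code would then pass as `per_layer_sliding_window`.

```python
from typing import Optional


class DummyConfig:
    """Hand-written stand-in for an HF text config (values are made up)."""
    sliding_window = 4096
    layer_types = ["sliding_attention", "full_attention"] * 2  # 4 layers


def per_layer_window(config, layer_idx: int) -> Optional[int]:
    """Sliding window for this layer, or None for a full-attention layer."""
    if getattr(config, "layer_types", None) is None:
        # Uniform model: every layer shares config.sliding_window (may be None).
        return getattr(config, "sliding_window", None)
    if config.layer_types[layer_idx] == "sliding_attention":
        return config.sliding_window
    return None


cfg = DummyConfig()
print([per_layer_window(cfg, i) for i in range(len(cfg.layer_types))])
# [4096, None, 4096, None] -> each value feeds Attention(per_layer_sliding_window=...)
```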

View File

@@ -200,28 +200,6 @@ def test_disable_sliding_window(model_id_expected):
     assert model_config.max_model_len == expected


-def test_get_sliding_window():
-    TEST_SLIDING_WINDOW = 4096
-    # Test that the sliding window is correctly computed.
-    # For Qwen1.5/Qwen2, get_sliding_window() should be None
-    # when use_sliding_window is False.
-    qwen2_model_config = ModelConfig("Qwen/Qwen1.5-7B")
-    qwen2_model_config.hf_config.use_sliding_window = False
-    qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
-    assert qwen2_model_config.get_sliding_window() is None
-    qwen2_model_config.hf_config.use_sliding_window = True
-    assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
-
-    mistral_model_config = ModelConfig("mistralai/Mistral-7B-v0.1")
-    mistral_model_config.hf_config.sliding_window = None
-    assert mistral_model_config.get_sliding_window() is None
-    mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
-    assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
-
-
 @pytest.mark.skipif(current_platform.is_rocm(),
                     reason="Xformers backend is not supported on ROCm.")
 def test_get_pooling_config():

View File

@@ -40,8 +40,9 @@ from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder,
-    maybe_override_with_speculators_target_model, try_get_generation_config,
-    try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope)
+    is_interleaved, maybe_override_with_speculators_target_model,
+    try_get_generation_config, try_get_safetensors_metadata,
+    try_get_tokenizer_config, uses_mrope)
 from vllm.transformers_utils.s3_utils import S3Model
 from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
 # yapf conflicts with isort for this block
@@ -714,53 +715,31 @@ class ModelConfig:
             revision=self.revision,
         )

-        # Workaround for Gemma 2 which uses interleaved sliding window
-        # attention, but it's not specified in its config.
-        # TODO: remove this when Gemma 2 config updated in HuggingFace.
-        if self.hf_text_config.model_type == "gemma2":
-            self.hf_text_config.sliding_window_pattern = 2
-
-        # TODO: remove this when Gemma 3n config updated in HuggingFace.
-        if self.hf_text_config.model_type == "gemma3n_text":
-            # 4 sliding window attention followed by 1 full attention
-            self.hf_text_config.sliding_window_pattern = "LLLLG"
-
-        sliding_window = getattr(self.hf_text_config, "sliding_window", None)
-        sliding_window_pattern = getattr(self.hf_text_config,
-                                         "sliding_window_pattern", None)
-        has_interleaved_attention = sliding_window_pattern is not None or (
-            isinstance(sliding_window, list))
-
-        if not self.disable_sliding_window and has_interleaved_attention:
-            if not envs.VLLM_USE_V1 and (backend := envs.VLLM_ATTENTION_BACKEND
-                                         ) in ("XFORMERS", "FLASHINFER"):
-                sliding_window_len_min = get_min_sliding_window(
-                    self.hf_text_config.sliding_window)
-
-                logger.warning_once(
-                    "%s has interleaved attention, which is currently not supported by the %s backend. Disabling sliding window and capping the max length to the sliding window size (%d).",  # noqa: E501
-                    self.hf_text_config.model_type,
-                    backend,
-                    sliding_window_len_min,
-                )
-                self.disable_sliding_window = True
-            else:
-                # for a model with interleaved attention,
-                # the scheduler and the model treat it as full attention
-                # (i.e., not dropping any tokens outside the window).
-                # only the attention layer itself is aware of the sliding
-                # window, and use the window size to compute the attention.
-                self.hf_text_config.interleaved_sliding_window = sliding_window
-                if hasattr(self.hf_text_config, "sliding_window"):
-                    delattr(self.hf_text_config, "sliding_window")
-                sliding_window = None
+        # Interleaved attention is not supported by some backends in V0
+        if (not self.disable_sliding_window
+                and is_interleaved(self.hf_text_config)
+                and not envs.VLLM_USE_V1
+                and (backend := envs.VLLM_ATTENTION_BACKEND)
+                in ("XFORMERS", "FLASHINFER")):
+            logger.warning_once(
+                "%s has interleaved attention, which is currently not "
+                "supported by the %s backend. Disabling sliding window and "
+                "capping the max length to the sliding window size (%d).",
+                self.hf_text_config.model_type,
+                backend,
+                self.hf_text_config.sliding_window,
+            )
+            self.disable_sliding_window = True

         self.original_max_model_len = self.max_model_len
         self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
         self.multimodal_config = self._init_multimodal_config()

-        if self.disable_sliding_window:
-            # Set after get_and_verify_max_len to ensure that max_model_len
-            # can be correctly capped to sliding window size
-            self.hf_text_config.sliding_window = None
-
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()
@@ -1322,27 +1301,10 @@ class ModelConfig:
         if self.use_async_output_proc:
             self.use_async_output_proc = False

-    def get_hf_config_sliding_window(
-            self) -> Union[Optional[int], list[Optional[int]]]:
-        """Get the sliding window size, or None if disabled."""
-        # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
-        # addition to sliding window size. We check if that field is present
-        # and if it's False, return None.
-        if (hasattr(self.hf_text_config, "use_sliding_window")
-                and not self.hf_text_config.use_sliding_window):
-            return None
+    def get_sliding_window(self) -> Optional[int]:
+        """Get the sliding window size from the HF text config if present."""
         return getattr(self.hf_text_config, "sliding_window", None)

-    def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
-        """Get the sliding window size, or None if disabled.
-        """
-        # If user disables sliding window, return None.
-        if self.disable_sliding_window:
-            return None
-        # Otherwise get the value from the hf config.
-        return self.get_hf_config_sliding_window()
-
     def get_vocab_size(self) -> int:
         return getattr(self.hf_text_config, "vocab_size", 0)
@@ -1762,7 +1724,7 @@ class ModelConfig:
             tokenizer_config=tokenizer_config,
             max_model_len=max_model_len,
             disable_sliding_window=self.disable_sliding_window,
-            sliding_window_len=self.get_hf_config_sliding_window(),
+            sliding_window=self.get_sliding_window(),
             spec_target_max_model_len=self.spec_target_max_model_len,
             encoder_config=self.encoder_config)
         logger.info("Using max model len %s", max_model_len)
@@ -3305,7 +3267,7 @@ def _get_and_verify_max_len(
     tokenizer_config: Optional[dict],
     max_model_len: Optional[int],
     disable_sliding_window: bool,
-    sliding_window_len: Optional[Union[int, list[Optional[int]]]],
+    sliding_window: Optional[int],
    spec_target_max_model_len: Optional[int] = None,
    encoder_config: Optional[Any] = None,
 ) -> int:
@@ -3344,13 +3306,10 @@ def _get_and_verify_max_len(
     # If sliding window is manually disabled, max_length should be less
     # than the sliding window length in the model config.
-    if disable_sliding_window and sliding_window_len is not None:
-        sliding_window_len_min = get_min_sliding_window(sliding_window_len)
-        max_len_key = "sliding_window" \
-            if sliding_window_len_min < derived_max_model_len else max_len_key
-        derived_max_model_len = min(derived_max_model_len,
-                                    sliding_window_len_min)
+    if (disable_sliding_window and sliding_window is not None
+            and sliding_window < derived_max_model_len):
+        max_len_key = "sliding_window"
+        derived_max_model_len = sliding_window

     # Consider model_max_length in tokenizer_config
     if tokenizer_config:
@@ -3451,14 +3410,6 @@ def _get_and_verify_max_len(
     return int(max_model_len)


-def get_min_sliding_window(
-        sliding_window: Union[int, list[Optional[int]]]) -> int:
-    if isinstance(sliding_window, list):
-        return min(s for s in sliding_window if s is not None)
-
-    return sliding_window
-
-
 def get_served_model_name(model: str,
                           served_model_name: Optional[Union[str, list[str]]]):
     """

View File

@@ -39,6 +39,7 @@ from vllm.plugins import load_general_plugins
 from vllm.ray.lazy_utils import is_ray_initialized
 from vllm.reasoning import ReasoningParserManager
 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
+from vllm.transformers_utils.config import is_interleaved
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
                         GiB_bytes, get_ip, is_in_ray_actor)
@@ -1081,6 +1082,13 @@ class EngineArgs:
                 "DualChunkFlashAttention is not supported on V1 engine. "
                 "To run the model in V0 engine, try set 'VLLM_USE_V1=0'")

+        sliding_window: Optional[int] = None
+        if not is_interleaved(model_config.hf_text_config):
+            # Only set CacheConfig.sliding_window if the model is all sliding
+            # window. Otherwise CacheConfig.sliding_window will override the
+            # global layers in interleaved sliding window models.
+            sliding_window = model_config.get_sliding_window()
+
         cache_config = CacheConfig(
             block_size=self.block_size,
             gpu_memory_utilization=self.gpu_memory_utilization,
@@ -1088,7 +1096,7 @@
             cache_dtype=self.kv_cache_dtype,
             is_attention_free=model_config.is_attention_free,
             num_gpu_blocks_override=self.num_gpu_blocks_override,
-            sliding_window=model_config.get_sliding_window(),
+            sliding_window=sliding_window,
             enable_prefix_caching=self.enable_prefix_caching,
             prefix_caching_hash_algo=self.prefix_caching_hash_algo,
             cpu_offload_gb=self.cpu_offload_gb,
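A small sketch of the gating logic introduced here (the helper name and inputs below are illustrative, not vLLM API): the cache config only receives a sliding window when every layer uses one; interleaved models keep their per-layer windows inside the model code.

```python
from typing import Optional


def cache_config_window(layer_types: Optional[list[str]],
                        model_window: Optional[int]) -> Optional[int]:
    """Return the window to hand to the KV cache config, if any."""
    interleaved = (layer_types is not None
                   and {"full_attention", "sliding_attention"} <= set(layer_types))
    return None if interleaved else model_window


print(cache_config_window(["sliding_attention"] * 4, 4096))                    # 4096
print(cache_config_window(["sliding_attention", "full_attention"] * 2, 4096))  # None
```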

View File

@@ -182,21 +182,13 @@ class CohereAttention(nn.Module):
         )

         # Model v2 has interleaved sliding windows, v1 does not
-        interleaved_sliding_window = getattr(config,
-                                             "interleaved_sliding_window",
-                                             None)
-        self.v1 = interleaved_sliding_window is None
-
-        self.sliding_window = None
-        if not self.v1:
-            layer_idx = extract_layer_index(prefix)
-            layer_has_sliding_window = (
-                getattr(config, "sliding_window_pattern", False) and
-                (layer_idx + 1) % self.config.sliding_window_pattern
-                != 0) or (getattr(config, "layer_types", False)
-                          and config.layer_types[layer_idx] == "sliding_attention")
-            self.sliding_window = (interleaved_sliding_window
-                                   or config.sliding_window
-                                   if layer_has_sliding_window else None)
+        self.v1 = isinstance(config, CohereConfig)
+
+        self.sliding_window = None
+        layer_idx = extract_layer_index(prefix)
+        if config.layer_types[layer_idx] == "sliding_attention":
+            self.sliding_window = config.sliding_window

         self.attn = Attention(self.num_heads,
                               self.head_dim,

View File

@@ -159,25 +159,12 @@ class Exaone4Attention(nn.Module):
         if quant_config is not None and quant_config.get_name() == "gguf":
             is_neox_style = False

-        self.apply_all_layers = False  # apply rotary embeddings to every layer.
         layer_idx = extract_layer_index(prefix)
-        interleaved_sliding_window = getattr(config,
-                                             "interleaved_sliding_window",
-                                             4096)
-        sliding_window_pattern = getattr(config, "sliding_window_pattern",
-                                         "LLLG")
-        if sliding_window_pattern:
-            layer_has_sliding_window = (
-                layer_idx + 1) % sliding_window_pattern.__len__() != 0
-        else:
-            layer_has_sliding_window = False
-            self.apply_all_layers = True
-
-        if layer_has_sliding_window:
-            self.sliding_window = interleaved_sliding_window
-        else:
-            self.sliding_window = None
+        is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        self.sliding_window = config.sliding_window if is_sliding else None
+
+        # apply rotary embeddings to every layer
+        self.apply_all_layers = not is_sliding

         self.rotary_emb = get_rope(
             self.head_dim,

View File

@@ -144,13 +144,10 @@ class Gemma2Attention(nn.Module):
             is_neox_style=True,
         )

-        # reference:
-        # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312  # noqa
         layer_idx = extract_layer_index(prefix)
-        use_sliding_window = (layer_idx % 2 == 0 and getattr(
-            config, "interleaved_sliding_window", None) is not None)
-        sliding_window = config.interleaved_sliding_window if \
-            use_sliding_window else None
+        is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        sliding_window = config.sliding_window if is_sliding else None
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,

View File

@@ -146,25 +146,19 @@ class Gemma3Attention(nn.Module):
         self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
         self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)

-        # TODO(woosuk): Add reference to the original HF implementation.
         layer_idx = extract_layer_index(prefix)
-        self.is_sliding = (getattr(
-            config, "interleaved_sliding_window", None) is not None and (bool(
-                (layer_idx + 1) % config.sliding_window_pattern))) or (
-                    getattr(config, "layer_types", None) is not None
-                    and config.layer_types[layer_idx] == "sliding_attention")
+        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        sliding_window = config.sliding_window if self.is_sliding else None

         # Initialize the rotary embedding.
         if self.is_sliding:
             # Local attention. Override the values in config.json.
             self.rope_theta = config.rope_local_base_freq
             self.rope_scaling = {"rope_type": "default"}
-            self.sliding_window = (config.interleaved_sliding_window
-                                   or config.sliding_window)
         else:
             # Global attention. Use the values in config.json.
             self.rope_theta = config.rope_theta
             self.rope_scaling = config.rope_scaling
-            self.sliding_window = None
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
@@ -182,7 +176,7 @@ class Gemma3Attention(nn.Module):
             cache_config=cache_config,
             quant_config=quant_config,
             logits_soft_cap=attn_logits_soft_cap,
-            per_layer_sliding_window=self.sliding_window,
+            per_layer_sliding_window=sliding_window,
             prefix=f"{prefix}.attn")

     def forward(

View File

@@ -502,8 +502,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         self.config = config
         self.quant_config = quant_config
         self.multimodal_config = multimodal_config
-        self.sliding_window = getattr(config.text_config,
-                                      "interleaved_sliding_window", None)

         self.vision_tower = SiglipVisionModel(config.vision_config,
                                               quant_config,
@@ -690,11 +688,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
                 global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
                 global_attn_masks.append(global_attn_mask)

-                if self.sliding_window is not None:
+                if (sliding_window := self.config.sliding_window) is not None:
                     # Create a local causal mask with sliding window (1024).
                     local_attn_mask = torch.ones_like(global_attn_mask)
                     local_attn_mask = torch.tril(local_attn_mask,
-                                                 diagonal=-self.sliding_window)
+                                                 diagonal=-sliding_window)
                     local_attn_mask = torch.where(local_attn_mask == 0,
                                                   global_attn_mask, float("-inf"))
                     local_attn_masks.append(local_attn_mask)

View File

@@ -313,17 +313,16 @@ class Gemma3nAttention(nn.Module):
                                  has_weight=False)

         layer_idx = extract_layer_index(prefix)
+        is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        self.sliding_window = config.sliding_window if is_sliding else None

-        is_sliding_window = (
-            getattr(config, "interleaved_sliding_window", None) is not None
-            and config.layer_types[layer_idx] == "sliding_attention")
-
-        if is_sliding_window:
-            self.sliding_window = config.interleaved_sliding_window
+        # Initialize the rotary embedding.
+        if is_sliding:
+            # Local attention. Override the values in config.json.
             rope_theta = config.rope_local_base_freq
             rope_scaling = {"rope_type": "default"}
         else:
-            self.sliding_window = None
+            # Global attention. Use the values in config.json.
             rope_theta = config.rope_theta
             rope_scaling = config.rope_scaling

View File

@@ -248,9 +248,7 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
             vllm_config.cache_config.sliding_window = None
-            for attr in ("sliding_window", "interleaved_sliding_window"):
-                if hasattr(hf_config, attr):
-                    delattr(hf_config, attr)
+            hf_config.sliding_window = None

         super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

View File

@@ -167,18 +167,11 @@ class LlamaAttention(nn.Module):
             rope_scaling=rope_scaling,
             quant_config=quant_config)

-        if hasattr(config, "interleaved_sliding_window"):
-            interleaved_sliding_window = config.interleaved_sliding_window
-            if isinstance(interleaved_sliding_window, int):
-                sliding_window = interleaved_sliding_window
-            elif isinstance(interleaved_sliding_window, list):
-                sw_idx = layer_idx % len(interleaved_sliding_window)
-                sliding_window = interleaved_sliding_window[sw_idx]
-            else:
-                raise ValueError(
-                    f"{type(interleaved_sliding_window)} is not supported.")
-        else:
-            sliding_window = None
+        sliding_window = None
+        if layer_types := getattr(config, "layer_types", None):
+            is_sliding = layer_types[layer_idx] == "sliding_attention"
+            if is_sliding:
+                sliding_window = config.sliding_window

         self.attn = Attention(
             self.num_heads,

View File

@@ -116,13 +116,8 @@ class SambaYAttention(nn.Module):
         self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True)

         # disable sliding window for the second half of the model
-        sliding_window = config.interleaved_sliding_window[layer_idx]
-        if layer_idx >= config.num_hidden_layers // 2:
-            assert sliding_window is None, \
-                "sliding_window must be none for the second decoder"
-        else:
-            assert sliding_window is not None, \
-                "sliding_window must be set for the first decoder"
+        is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        sliding_window = config.sliding_window if is_sliding else None

         assert self.num_heads % 2 == 0, 'num_heads should be even'
         assert self.num_key_value_heads % 2 == 0, 'num_heads should be even'

View File

@@ -49,6 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.config import is_interleaved

 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
@@ -285,8 +286,7 @@ class Qwen2Model(nn.Module):
         quant_config = vllm_config.quant_config

         # TODO (@robertgshaw2): see if this can be moved out
-        if (cache_config.sliding_window is not None
-                and hasattr(config, "max_window_layers")):
+        if is_interleaved(vllm_config.model_config.hf_text_config):
             assert config.max_window_layers == config.num_hidden_layers, (
                 "Sliding window for some but all layers is not supported. "
                 "This model uses sliding window but `max_window_layers` = {} "

View File

@@ -16,7 +16,7 @@
 # limitations under the License.
 """Wrapper around `transformers` models"""
 from collections.abc import Iterable, Mapping
-from contextlib import contextmanager, nullcontext
+from contextlib import contextmanager
 from typing import Literal, Optional, Union

 import regex as re
@@ -382,33 +382,6 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
         )


-class ConfigOverride:
-    """Context manager to temporarily override config attributes."""
-
-    def __init__(self, config: PretrainedConfig, **kwargs):
-        self.config = config
-        self.kwargs = kwargs
-        self.kwargs_original = {}
-        self.kwargs_delete = set()
-
-    def __enter__(self):
-        """Override config attributes."""
-        for key, value in self.kwargs.items():
-            if not hasattr(self.config, key):
-                self.kwargs_delete.add(key)
-            self.kwargs_original[key] = getattr(self.config, key, None)
-            setattr(self.config, key, value)
-        return self.config
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        """Restore original config attributes."""
-        for key, value in self.kwargs_original.items():
-            if key in self.kwargs_delete:
-                delattr(self.config, key)
-            else:
-                setattr(self.config, key, value)
-
-
 class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
     embedding_padding_modules = ["lm_head"]
     embedding_modules = ["embed_tokens"
@@ -434,21 +407,11 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         # To be updated in child classes for use in `load_weights`
         self.skip_prefixes: Optional[list[str]] = None

-        # vLLM handles interleaved sliding window attention by creating a new
-        # interleaved_sliding_window attribute and deleting the sliding_window
-        # attribute. This breaks the constructors in Transformers so we
-        # temporarily add the attribute back to construct the model.
-        config_override = nullcontext()
-        if hasattr(self.config, "interleaved_sliding_window"):
-            config_override = ConfigOverride(
-                self.config,
-                sliding_window=self.config.interleaved_sliding_window)
-
         # Set correct attn and init on "meta" to delay allocating GPU tensors
         # TODO: @raushan, use the public `model.set_attn_implementation()`
         # method once its checks are fixed in Transformers.
         self.text_config._attn_implementation = "vllm"
-        with init_on_device_without_buffers("meta"), config_override:
+        with init_on_device_without_buffers("meta"):
             self.model: PreTrainedModel = AutoModel.from_config(
                 self.config,
                 torch_dtype=self.model_config.dtype,
@@ -575,11 +538,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         attention_instances = {}
         for i in range(start, end):
             # Handle interleaved sliding window attention
-            sliding_window = None
-            if (hasattr(self.config, "interleaved_sliding_window")
-                    and hasattr(self.config, "sliding_window_pattern")
-                    and ((i + 1) % self.config.sliding_window_pattern > 0)):
-                sliding_window = self.config.interleaved_sliding_window
+            per_layer_sliding_window = None
+            if (hasattr(self.config, "layer_types")
+                    and self.config.layer_types[i] == "sliding_attention"):
+                per_layer_sliding_window = self.config.sliding_window

             attention_instances[i] = Attention(
                 num_heads=num_heads,
@@ -590,7 +552,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
                 num_kv_heads=num_kv_heads,
                 cache_config=self.cache_config,
                 quant_config=self.quant_config,
-                per_layer_sliding_window=sliding_window,
+                per_layer_sliding_window=per_layer_sliding_window,
                 prefix=f"{i}.attn")

         return attention_instances

View File

@@ -280,6 +280,17 @@ def is_encoder_decoder(config: PretrainedConfig) -> bool:
     return getattr(config, "is_encoder_decoder", False)


+def is_interleaved(config: PretrainedConfig) -> bool:
+    """
+    Detect if the model with this config is used with interleaved attention.
+    """
+    text_config = config.get_text_config()
+    if layer_types := getattr(text_config, "layer_types", None):
+        interleaved_types = {"full_attention", "sliding_attention"}
+        return interleaved_types.issubset(layer_types)
+    return False
+
+
 def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
     """Remap config attributes to match the expected names."""
     for old_attr, new_attr in _CONFIG_ATTRS_MAPPING.items():
@@ -423,6 +434,23 @@ def get_config(
                 raise e

         config = _maybe_remap_hf_config_attrs(config)
+
+        # Phi4Flash misuses this config as list[int]. Convert it to int and add
+        # the layer_types list[str] to make it HF compatible
+        if config.model_type == "phi4flash":
+            # TODO: Remove after the following PR is merged:
+            # https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/discussions/6
+            if not hasattr(config, "layer_types"):
+                config.layer_types = [
+                    "sliding_attention" if i < config.num_hidden_layers // 2
+                    and i % 2 == 1 else "full_attention"
+                    for i in range(config.num_hidden_layers)
+                ]
+            # TODO: Remove after the following PR is merged:
+            # https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/discussions/7
+            if isinstance(config.sliding_window, list):
+                config.sliding_window = next(
+                    filter(None, config.sliding_window), None)
+
     elif config_format == ConfigFormat.MISTRAL:
         # This function loads a params.json config which
         # should be used when loading models in mistral format
@@ -434,6 +462,18 @@ def get_config(
             config_dict["max_position_embeddings"] = max_position_embeddings
         config = adapt_config_dict(config_dict)

+        # Mistral configs may define sliding_window as list[int]. Convert it
+        # to int and add the layer_types list[str] to make it HF compatible
+        if ((sliding_window := getattr(config, "sliding_window", None))
+                and isinstance(sliding_window, list)):
+            pattern_repeats = config.num_hidden_layers // len(sliding_window)
+            layer_types = sliding_window * pattern_repeats
+            config.layer_types = [
+                "full_attention" if layer_type is None else "sliding_attention"
+                for layer_type in layer_types
+            ]
+            config.sliding_window = next(filter(None, sliding_window), None)
+
     else:
         supported_formats = [
             fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO
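To make the two conversions above concrete, here is a standalone sketch (the function names and the 8-layer example are illustrative, not the vLLM helpers themselves) of how a Mistral-style `sliding_window` list maps onto a scalar `sliding_window` plus `layer_types`, and how interleaving is then detected from `layer_types`.

```python
from typing import Optional


def to_layer_types(sliding_window: list[Optional[int]],
                   num_hidden_layers: int) -> tuple[Optional[int], list[str]]:
    """Convert a per-layer window list into (scalar window, layer_types)."""
    pattern_repeats = num_hidden_layers // len(sliding_window)
    per_layer = sliding_window * pattern_repeats
    layer_types = [
        "full_attention" if w is None else "sliding_attention" for w in per_layer
    ]
    window = next(filter(None, sliding_window), None)
    return window, layer_types


def is_interleaved(layer_types: list[str]) -> bool:
    """True when both full-attention and sliding-attention layers are present."""
    return {"full_attention", "sliding_attention"}.issubset(layer_types)


# e.g. a hypothetical 8-layer model alternating windowed and full attention
window, layer_types = to_layer_types([4096, None], num_hidden_layers=8)
print(window)                       # 4096
print(layer_types[:4])              # ['sliding_attention', 'full_attention', ...]
print(is_interleaved(layer_types))  # True
```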