Refactor sliding window configuration to Transformers best practice (#21927)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-08-10 04:50:48 +01:00 committed by GitHub
parent 2a84fb422f
commit c49848396d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 123 additions and 231 deletions

View File

@ -117,7 +117,7 @@ For models with interleaving sliding windows (e.g. `google/gemma-2-2b-it` and `m
To support a model with interleaving sliding windows, we need to take care of the following details:
- Make sure the model's `config.json` contains `sliding_window_pattern`. vLLM then sets `self.hf_text_config.interleaved_sliding_window` to the value of `self.hf_text_config.sliding_window` and deletes `sliding_window` from `self.hf_text_config`. The model will then be treated as a full-attention model.
- Make sure the model's `config.json` contains `layer_types`.
- In the modeling code, parse the correct sliding window value for every layer, and pass it to the attention layer's `per_layer_sliding_window` argument. For reference, check [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/model_executor/models/llama.py#L171).
With these two steps, interleave sliding windows should work with the model.

View File

@ -200,28 +200,6 @@ def test_disable_sliding_window(model_id_expected):
assert model_config.max_model_len == expected
def test_get_sliding_window():
TEST_SLIDING_WINDOW = 4096
# Test that the sliding window is correctly computed.
# For Qwen1.5/Qwen2, get_sliding_window() should be None
# when use_sliding_window is False.
qwen2_model_config = ModelConfig("Qwen/Qwen1.5-7B")
qwen2_model_config.hf_config.use_sliding_window = False
qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
assert qwen2_model_config.get_sliding_window() is None
qwen2_model_config.hf_config.use_sliding_window = True
assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
mistral_model_config = ModelConfig("mistralai/Mistral-7B-v0.1")
mistral_model_config.hf_config.sliding_window = None
assert mistral_model_config.get_sliding_window() is None
mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config():

View File

@ -40,8 +40,9 @@ from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
maybe_override_with_speculators_target_model, try_get_generation_config,
try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope)
is_interleaved, maybe_override_with_speculators_target_model,
try_get_generation_config, try_get_safetensors_metadata,
try_get_tokenizer_config, uses_mrope)
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
# yapf conflicts with isort for this block
@ -714,53 +715,31 @@ class ModelConfig:
revision=self.revision,
)
# Workaround for Gemma 2 which uses interleaved sliding window
# attention, but it's not specified in its config.
# TODO: remove this when Gemma 2 config updated in HuggingFace.
if self.hf_text_config.model_type == "gemma2":
self.hf_text_config.sliding_window_pattern = 2
# TODO: remove this when Gemma 3n config updated in HuggingFace.
if self.hf_text_config.model_type == "gemma3n_text":
# 4 sliding window attention followed by 1 full attention
self.hf_text_config.sliding_window_pattern = "LLLLG"
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
sliding_window_pattern = getattr(self.hf_text_config,
"sliding_window_pattern", None)
has_interleaved_attention = sliding_window_pattern is not None or (
isinstance(sliding_window, list))
if not self.disable_sliding_window and has_interleaved_attention:
if not envs.VLLM_USE_V1 and (backend := envs.VLLM_ATTENTION_BACKEND
) in ("XFORMERS", "FLASHINFER"):
sliding_window_len_min = get_min_sliding_window(
self.hf_text_config.sliding_window)
logger.warning_once(
"%s has interleaved attention, which is currently not supported by the %s backend. Disabling sliding window and capping the max length to the sliding window size (%d).", # noqa: E501
self.hf_text_config.model_type,
backend,
sliding_window_len_min,
)
self.disable_sliding_window = True
else:
# for a model with interleaved attention,
# the scheduler and the model treat it as full attention
# (i.e., not dropping any tokens outside the window).
# only the attention layer itself is aware of the sliding
# window, and use the window size to compute the attention.
self.hf_text_config.interleaved_sliding_window = sliding_window
if hasattr(self.hf_text_config, "sliding_window"):
delattr(self.hf_text_config, "sliding_window")
sliding_window = None
# Interleaved attention is not supported by some backends in V0
if (not self.disable_sliding_window
and is_interleaved(self.hf_text_config)
and not envs.VLLM_USE_V1
and (backend := envs.VLLM_ATTENTION_BACKEND)
in ("XFORMERS", "FLASHINFER")):
logger.warning_once(
"%s has interleaved attention, which is currently not "
"supported by the %s backend. Disabling sliding window and "
"capping the max length to the sliding window size (%d).",
self.hf_text_config.model_type,
backend,
self.hf_text_config.sliding_window,
)
self.disable_sliding_window = True
self.original_max_model_len = self.max_model_len
self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
self.multimodal_config = self._init_multimodal_config()
if self.disable_sliding_window:
# Set after get_and_verify_max_len to ensure that max_model_len
# can be correctly capped to sliding window size
self.hf_text_config.sliding_window = None
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
@ -1322,27 +1301,10 @@ class ModelConfig:
if self.use_async_output_proc:
self.use_async_output_proc = False
def get_hf_config_sliding_window(
self) -> Union[Optional[int], list[Optional[int]]]:
"""Get the sliding window size, or None if disabled."""
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
# addition to sliding window size. We check if that field is present
# and if it's False, return None.
if (hasattr(self.hf_text_config, "use_sliding_window")
and not self.hf_text_config.use_sliding_window):
return None
def get_sliding_window(self) -> Optional[int]:
"""Get the sliding window size from the HF text config if present."""
return getattr(self.hf_text_config, "sliding_window", None)
def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
"""Get the sliding window size, or None if disabled.
"""
# If user disables sliding window, return None.
if self.disable_sliding_window:
return None
# Otherwise get the value from the hf config.
return self.get_hf_config_sliding_window()
def get_vocab_size(self) -> int:
return getattr(self.hf_text_config, "vocab_size", 0)
@ -1762,7 +1724,7 @@ class ModelConfig:
tokenizer_config=tokenizer_config,
max_model_len=max_model_len,
disable_sliding_window=self.disable_sliding_window,
sliding_window_len=self.get_hf_config_sliding_window(),
sliding_window=self.get_sliding_window(),
spec_target_max_model_len=self.spec_target_max_model_len,
encoder_config=self.encoder_config)
logger.info("Using max model len %s", max_model_len)
@ -3305,7 +3267,7 @@ def _get_and_verify_max_len(
tokenizer_config: Optional[dict],
max_model_len: Optional[int],
disable_sliding_window: bool,
sliding_window_len: Optional[Union[int, list[Optional[int]]]],
sliding_window: Optional[int],
spec_target_max_model_len: Optional[int] = None,
encoder_config: Optional[Any] = None,
) -> int:
@ -3344,13 +3306,10 @@ def _get_and_verify_max_len(
# If sliding window is manually disabled, max_length should be less
# than the sliding window length in the model config.
if disable_sliding_window and sliding_window_len is not None:
sliding_window_len_min = get_min_sliding_window(sliding_window_len)
max_len_key = "sliding_window" \
if sliding_window_len_min < derived_max_model_len else max_len_key
derived_max_model_len = min(derived_max_model_len,
sliding_window_len_min)
if (disable_sliding_window and sliding_window is not None
and sliding_window < derived_max_model_len):
max_len_key = "sliding_window"
derived_max_model_len = sliding_window
# Consider model_max_length in tokenizer_config
if tokenizer_config:
@ -3451,14 +3410,6 @@ def _get_and_verify_max_len(
return int(max_model_len)
def get_min_sliding_window(
sliding_window: Union[int, list[Optional[int]]]) -> int:
if isinstance(sliding_window, list):
return min(s for s in sliding_window if s is not None)
return sliding_window
def get_served_model_name(model: str,
served_model_name: Optional[Union[str, list[str]]]):
"""

View File

@ -39,6 +39,7 @@ from vllm.plugins import load_general_plugins
from vllm.ray.lazy_utils import is_ray_initialized
from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.config import is_interleaved
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
GiB_bytes, get_ip, is_in_ray_actor)
@ -1081,6 +1082,13 @@ class EngineArgs:
"DualChunkFlashAttention is not supported on V1 engine. "
"To run the model in V0 engine, try set 'VLLM_USE_V1=0'")
sliding_window: Optional[int] = None
if not is_interleaved(model_config.hf_text_config):
# Only set CacheConfig.sliding_window if the model is all sliding
# window. Otherwise CacheConfig.sliding_window will override the
# global layers in interleaved sliding window models.
sliding_window = model_config.get_sliding_window()
cache_config = CacheConfig(
block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization,
@ -1088,7 +1096,7 @@ class EngineArgs:
cache_dtype=self.kv_cache_dtype,
is_attention_free=model_config.is_attention_free,
num_gpu_blocks_override=self.num_gpu_blocks_override,
sliding_window=model_config.get_sliding_window(),
sliding_window=sliding_window,
enable_prefix_caching=self.enable_prefix_caching,
prefix_caching_hash_algo=self.prefix_caching_hash_algo,
cpu_offload_gb=self.cpu_offload_gb,

View File

@ -182,21 +182,13 @@ class CohereAttention(nn.Module):
)
# Model v2 has interleaved sliding windows, v1 does not
interleaved_sliding_window = getattr(config,
"interleaved_sliding_window",
None)
self.v1 = interleaved_sliding_window is None
self.v1 = isinstance(config, CohereConfig)
layer_idx = extract_layer_index(prefix)
layer_has_sliding_window = (
getattr(config, "sliding_window_pattern", False) and
(layer_idx + 1) % self.config.sliding_window_pattern
!= 0) or (getattr(config, "layer_types", False)
and config.layer_types[layer_idx] == "sliding_attention")
self.sliding_window = (interleaved_sliding_window
or config.sliding_window
if layer_has_sliding_window else None)
self.sliding_window = None
if not self.v1:
layer_idx = extract_layer_index(prefix)
if config.layer_types[layer_idx] == "sliding_attention":
self.sliding_window = config.sliding_window
self.attn = Attention(self.num_heads,
self.head_dim,

View File

@ -159,25 +159,12 @@ class Exaone4Attention(nn.Module):
if quant_config is not None and quant_config.get_name() == "gguf":
is_neox_style = False
self.apply_all_layers = False # apply rotary embeddings to every layer.
layer_idx = extract_layer_index(prefix)
interleaved_sliding_window = getattr(config,
"interleaved_sliding_window",
4096)
sliding_window_pattern = getattr(config, "sliding_window_pattern",
"LLLG")
is_sliding = config.layer_types[layer_idx] == "sliding_attention"
self.sliding_window = config.sliding_window if is_sliding else None
if sliding_window_pattern:
layer_has_sliding_window = (
layer_idx + 1) % sliding_window_pattern.__len__() != 0
else:
layer_has_sliding_window = False
self.apply_all_layers = True
if layer_has_sliding_window:
self.sliding_window = interleaved_sliding_window
else:
self.sliding_window = None
# apply rotary embeddings to every layer
self.apply_all_layers = not is_sliding
self.rotary_emb = get_rope(
self.head_dim,

View File

@ -144,13 +144,10 @@ class Gemma2Attention(nn.Module):
is_neox_style=True,
)
# reference:
# https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa
layer_idx = extract_layer_index(prefix)
use_sliding_window = (layer_idx % 2 == 0 and getattr(
config, "interleaved_sliding_window", None) is not None)
sliding_window = config.interleaved_sliding_window if \
use_sliding_window else None
is_sliding = config.layer_types[layer_idx] == "sliding_attention"
sliding_window = config.sliding_window if is_sliding else None
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,

View File

@ -146,25 +146,19 @@ class Gemma3Attention(nn.Module):
self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
# TODO(woosuk): Add reference to the original HF implementation.
layer_idx = extract_layer_index(prefix)
self.is_sliding = (getattr(
config, "interleaved_sliding_window", None) is not None and (bool(
(layer_idx + 1) % config.sliding_window_pattern))) or (
getattr(config, "layer_types", None) is not None
and config.layer_types[layer_idx] == "sliding_attention")
self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
sliding_window = config.sliding_window if self.is_sliding else None
# Initialize the rotary embedding.
if self.is_sliding:
# Local attention. Override the values in config.json.
self.rope_theta = config.rope_local_base_freq
self.rope_scaling = {"rope_type": "default"}
self.sliding_window = (config.interleaved_sliding_window
or config.sliding_window)
else:
# Global attention. Use the values in config.json.
self.rope_theta = config.rope_theta
self.rope_scaling = config.rope_scaling
self.sliding_window = None
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
@ -182,7 +176,7 @@ class Gemma3Attention(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
logits_soft_cap=attn_logits_soft_cap,
per_layer_sliding_window=self.sliding_window,
per_layer_sliding_window=sliding_window,
prefix=f"{prefix}.attn")
def forward(

View File

@ -502,8 +502,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self.config = config
self.quant_config = quant_config
self.multimodal_config = multimodal_config
self.sliding_window = getattr(config.text_config,
"interleaved_sliding_window", None)
self.vision_tower = SiglipVisionModel(config.vision_config,
quant_config,
@ -690,11 +688,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
global_attn_masks.append(global_attn_mask)
if self.sliding_window is not None:
if (sliding_window := self.config.sliding_window) is not None:
# Create a local causal mask with sliding window (1024).
local_attn_mask = torch.ones_like(global_attn_mask)
local_attn_mask = torch.tril(local_attn_mask,
diagonal=-self.sliding_window)
diagonal=-sliding_window)
local_attn_mask = torch.where(local_attn_mask == 0,
global_attn_mask, float("-inf"))
local_attn_masks.append(local_attn_mask)

View File

@ -313,17 +313,16 @@ class Gemma3nAttention(nn.Module):
has_weight=False)
layer_idx = extract_layer_index(prefix)
is_sliding = config.layer_types[layer_idx] == "sliding_attention"
self.sliding_window = config.sliding_window if is_sliding else None
is_sliding_window = (
getattr(config, "interleaved_sliding_window", None) is not None
and config.layer_types[layer_idx] == "sliding_attention")
if is_sliding_window:
self.sliding_window = config.interleaved_sliding_window
# Initialize the rotary embedding.
if is_sliding:
# Local attention. Override the values in config.json.
rope_theta = config.rope_local_base_freq
rope_scaling = {"rope_type": "default"}
else:
self.sliding_window = None
# Global attention. Use the values in config.json.
rope_theta = config.rope_theta
rope_scaling = config.rope_scaling

View File

@ -248,9 +248,7 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
vllm_config.cache_config.sliding_window = None
for attr in ("sliding_window", "interleaved_sliding_window"):
if hasattr(hf_config, attr):
delattr(hf_config, attr)
hf_config.sliding_window = None
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

View File

@ -167,18 +167,11 @@ class LlamaAttention(nn.Module):
rope_scaling=rope_scaling,
quant_config=quant_config)
if hasattr(config, "interleaved_sliding_window"):
interleaved_sliding_window = config.interleaved_sliding_window
if isinstance(interleaved_sliding_window, int):
sliding_window = interleaved_sliding_window
elif isinstance(interleaved_sliding_window, list):
sw_idx = layer_idx % len(interleaved_sliding_window)
sliding_window = interleaved_sliding_window[sw_idx]
else:
raise ValueError(
f"{type(interleaved_sliding_window)} is not supported.")
else:
sliding_window = None
sliding_window = None
if layer_types := getattr(config, "layer_types", None):
is_sliding = layer_types[layer_idx] == "sliding_attention"
if is_sliding:
sliding_window = config.sliding_window
self.attn = Attention(
self.num_heads,

View File

@ -116,13 +116,8 @@ class SambaYAttention(nn.Module):
self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True)
# disable sliding window for the second half of the model
sliding_window = config.interleaved_sliding_window[layer_idx]
if layer_idx >= config.num_hidden_layers // 2:
assert sliding_window is None, \
"sliding_window must be none for the second decoder"
else:
assert sliding_window is not None, \
"sliding_window must be set for the first decoder"
is_sliding = config.layer_types[layer_idx] == "sliding_attention"
sliding_window = config.sliding_window if is_sliding else None
assert self.num_heads % 2 == 0, 'num_heads should be even'
assert self.num_key_value_heads % 2 == 0, 'num_heads should be even'

View File

@ -49,6 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import is_interleaved
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
@ -285,8 +286,7 @@ class Qwen2Model(nn.Module):
quant_config = vllm_config.quant_config
# TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")):
if is_interleaved(vllm_config.model_config.hf_text_config):
assert config.max_window_layers == config.num_hidden_layers, (
"Sliding window for some but all layers is not supported. "
"This model uses sliding window but `max_window_layers` = {} "

View File

@ -16,7 +16,7 @@
# limitations under the License.
"""Wrapper around `transformers` models"""
from collections.abc import Iterable, Mapping
from contextlib import contextmanager, nullcontext
from contextlib import contextmanager
from typing import Literal, Optional, Union
import regex as re
@ -382,33 +382,6 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
)
class ConfigOverride:
"""Context manager to temporarily override config attributes."""
def __init__(self, config: PretrainedConfig, **kwargs):
self.config = config
self.kwargs = kwargs
self.kwargs_original = {}
self.kwargs_delete = set()
def __enter__(self):
"""Override config attributes."""
for key, value in self.kwargs.items():
if not hasattr(self.config, key):
self.kwargs_delete.add(key)
self.kwargs_original[key] = getattr(self.config, key, None)
setattr(self.config, key, value)
return self.config
def __exit__(self, exc_type, exc_value, traceback):
"""Restore original config attributes."""
for key, value in self.kwargs_original.items():
if key in self.kwargs_delete:
delattr(self.config, key)
else:
setattr(self.config, key, value)
class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"
@ -434,21 +407,11 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
# To be updated in child classes for use in `load_weights`
self.skip_prefixes: Optional[list[str]] = None
# vLLM handles interleaved sliding window attention by creating a new
# interleaved_sliding_window attribute and deleting the sliding_window
# attribute. This breaks the constructors in Transformers so we
# temporarily add the attribute back to construct the model.
config_override = nullcontext()
if hasattr(self.config, "interleaved_sliding_window"):
config_override = ConfigOverride(
self.config,
sliding_window=self.config.interleaved_sliding_window)
# Set correct attn and init on "meta" to delay allocating GPU tensors
# TODO: @raushan, use the public `model.set_attn_implementation()`
# method once its checks are fixed in Transformers.
self.text_config._attn_implementation = "vllm"
with init_on_device_without_buffers("meta"), config_override:
with init_on_device_without_buffers("meta"):
self.model: PreTrainedModel = AutoModel.from_config(
self.config,
torch_dtype=self.model_config.dtype,
@ -575,11 +538,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
attention_instances = {}
for i in range(start, end):
# Handle interleaved sliding window attention
sliding_window = None
if (hasattr(self.config, "interleaved_sliding_window")
and hasattr(self.config, "sliding_window_pattern")
and ((i + 1) % self.config.sliding_window_pattern > 0)):
sliding_window = self.config.interleaved_sliding_window
per_layer_sliding_window = None
if (hasattr(self.config, "layer_types")
and self.config.layer_types[i] == "sliding_attention"):
per_layer_sliding_window = self.config.sliding_window
attention_instances[i] = Attention(
num_heads=num_heads,
@ -590,7 +552,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
num_kv_heads=num_kv_heads,
cache_config=self.cache_config,
quant_config=self.quant_config,
per_layer_sliding_window=sliding_window,
per_layer_sliding_window=per_layer_sliding_window,
prefix=f"{i}.attn")
return attention_instances

View File

@ -280,6 +280,17 @@ def is_encoder_decoder(config: PretrainedConfig) -> bool:
return getattr(config, "is_encoder_decoder", False)
def is_interleaved(config: PretrainedConfig) -> bool:
"""
Detect if the model with this config is used with interleaved attention.
"""
text_config = config.get_text_config()
if layer_types := getattr(text_config, "layer_types", None):
interleaved_types = {"full_attention", "sliding_attention"}
return interleaved_types.issubset(layer_types)
return False
def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
"""Remap config attributes to match the expected names."""
for old_attr, new_attr in _CONFIG_ATTRS_MAPPING.items():
@ -423,6 +434,23 @@ def get_config(
raise e
config = _maybe_remap_hf_config_attrs(config)
# Phi4Flash misuses this config as list[int]. Convert it to int and add
# the layer_types list[str] to make it HF compatible
if (config.model_type == "phi4flash"):
# TODO: Remove after the following PR is merged:
# https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/discussions/6
if not hasattr(config, "layer_types"):
config.layer_types = [
"sliding_attention" if i < config.num_hidden_layers // 2
and i % 2 == 1 else "full_attention"
for i in range(config.num_hidden_layers)
]
# TODO: Remove after the following PR is merged:
# https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/discussions/7
if isinstance(config.sliding_window, list):
config.sliding_window = next(
filter(None, config.sliding_window), None)
elif config_format == ConfigFormat.MISTRAL:
# This function loads a params.json config which
# should be used when loading models in mistral format
@ -434,6 +462,18 @@ def get_config(
config_dict["max_position_embeddings"] = max_position_embeddings
config = adapt_config_dict(config_dict)
# Mistral configs may define sliding_window as list[int]. Convert it
# to int and add the layer_types list[str] to make it HF compatible
if ((sliding_window := getattr(config, "sliding_window", None))
and isinstance(sliding_window, list)):
pattern_repeats = config.num_hidden_layers // len(sliding_window)
layer_types = sliding_window * pattern_repeats
config.layer_types = [
"full_attention" if layer_type is None else "sliding_attention"
for layer_type in layer_types
]
config.sliding_window = next(filter(None, sliding_window), None)
else:
supported_formats = [
fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO