Refactor sliding window configuration to Transformers best practice (#21927)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

parent 2a84fb422f
commit c49848396d
@@ -117,7 +117,7 @@ For models with interleaving sliding windows (e.g. `google/gemma-2-2b-it` and `m
 
 To support a model with interleaving sliding windows, we need to take care of the following details:
 
-- Make sure the model's `config.json` contains `sliding_window_pattern`. vLLM then sets `self.hf_text_config.interleaved_sliding_window` to the value of `self.hf_text_config.sliding_window` and deletes `sliding_window` from `self.hf_text_config`. The model will then be treated as a full-attention model.
+- Make sure the model's `config.json` contains `layer_types`.
 - In the modeling code, parse the correct sliding window value for every layer, and pass it to the attention layer's `per_layer_sliding_window` argument. For reference, check [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/model_executor/models/llama.py#L171).
 
 With these two steps, interleave sliding windows should work with the model.
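Read together, the two documentation bullets boil down to the small pattern sketched below. This is a hedged, self-contained illustration rather than code from the diff: `DummyTextConfig` and `per_layer_sliding_window` are invented names, and in a real vLLM model file the returned value would be forwarded to `Attention(..., per_layer_sliding_window=...)` as the linked `llama.py` line does.

```python
# Hypothetical sketch: `config.json` provides `layer_types` plus one integer
# `sliding_window`, and the modeling code derives a per-layer window from them.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class DummyTextConfig:  # stand-in for a HF text config
    sliding_window: int = 512
    layer_types: list = field(default_factory=lambda: [
        "sliding_attention", "sliding_attention", "full_attention",
    ] * 2)


def per_layer_sliding_window(config: DummyTextConfig,
                             layer_idx: int) -> Optional[int]:
    """Window for one layer: the configured size for "sliding_attention"
    layers, None (i.e. full attention) for everything else."""
    if config.layer_types[layer_idx] == "sliding_attention":
        return config.sliding_window
    return None


cfg = DummyTextConfig()
print([per_layer_sliding_window(cfg, i) for i in range(len(cfg.layer_types))])
# [512, 512, None, 512, 512, None]
```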
@@ -200,28 +200,6 @@ def test_disable_sliding_window(model_id_expected):
     assert model_config.max_model_len == expected
 
 
-def test_get_sliding_window():
-    TEST_SLIDING_WINDOW = 4096
-    # Test that the sliding window is correctly computed.
-    # For Qwen1.5/Qwen2, get_sliding_window() should be None
-    # when use_sliding_window is False.
-    qwen2_model_config = ModelConfig("Qwen/Qwen1.5-7B")
-
-    qwen2_model_config.hf_config.use_sliding_window = False
-    qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
-    assert qwen2_model_config.get_sliding_window() is None
-
-    qwen2_model_config.hf_config.use_sliding_window = True
-    assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
-
-    mistral_model_config = ModelConfig("mistralai/Mistral-7B-v0.1")
-    mistral_model_config.hf_config.sliding_window = None
-    assert mistral_model_config.get_sliding_window() is None
-
-    mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
-    assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
-
-
 @pytest.mark.skipif(current_platform.is_rocm(),
                     reason="Xformers backend is not supported on ROCm.")
 def test_get_pooling_config():
@@ -40,8 +40,9 @@ from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder,
-    maybe_override_with_speculators_target_model, try_get_generation_config,
-    try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope)
+    is_interleaved, maybe_override_with_speculators_target_model,
+    try_get_generation_config, try_get_safetensors_metadata,
+    try_get_tokenizer_config, uses_mrope)
 from vllm.transformers_utils.s3_utils import S3Model
 from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
 # yapf conflicts with isort for this block
@@ -714,53 +715,31 @@ class ModelConfig:
             revision=self.revision,
         )
 
-        # Workaround for Gemma 2 which uses interleaved sliding window
-        # attention, but it's not specified in its config.
-        # TODO: remove this when Gemma 2 config updated in HuggingFace.
-        if self.hf_text_config.model_type == "gemma2":
-            self.hf_text_config.sliding_window_pattern = 2
-        # TODO: remove this when Gemma 3n config updated in HuggingFace.
-        if self.hf_text_config.model_type == "gemma3n_text":
-            # 4 sliding window attention followed by 1 full attention
-            self.hf_text_config.sliding_window_pattern = "LLLLG"
-
-        sliding_window = getattr(self.hf_text_config, "sliding_window", None)
-        sliding_window_pattern = getattr(self.hf_text_config,
-                                         "sliding_window_pattern", None)
-        has_interleaved_attention = sliding_window_pattern is not None or (
-            isinstance(sliding_window, list))
-
-        if not self.disable_sliding_window and has_interleaved_attention:
-            if not envs.VLLM_USE_V1 and (backend := envs.VLLM_ATTENTION_BACKEND
-                                         ) in ("XFORMERS", "FLASHINFER"):
-                sliding_window_len_min = get_min_sliding_window(
-                    self.hf_text_config.sliding_window)
-
-                logger.warning_once(
-                    "%s has interleaved attention, which is currently not supported by the %s backend. Disabling sliding window and capping the max length to the sliding window size (%d).", # noqa: E501
-                    self.hf_text_config.model_type,
-                    backend,
-                    sliding_window_len_min,
-                )
-                self.disable_sliding_window = True
-            else:
-                # for a model with interleaved attention,
-                # the scheduler and the model treat it as full attention
-                # (i.e., not dropping any tokens outside the window).
-                # only the attention layer itself is aware of the sliding
-                # window, and use the window size to compute the attention.
-                self.hf_text_config.interleaved_sliding_window = sliding_window
-
-                if hasattr(self.hf_text_config, "sliding_window"):
-                    delattr(self.hf_text_config, "sliding_window")
-
-                sliding_window = None
+        # Interleaved attention is not supported by some backends in V0
+        if (not self.disable_sliding_window
+                and is_interleaved(self.hf_text_config)
+                and not envs.VLLM_USE_V1
+                and (backend := envs.VLLM_ATTENTION_BACKEND)
+                in ("XFORMERS", "FLASHINFER")):
+            logger.warning_once(
+                "%s has interleaved attention, which is currently not "
+                "supported by the %s backend. Disabling sliding window and "
+                "capping the max length to the sliding window size (%d).",
+                self.hf_text_config.model_type,
+                backend,
+                self.hf_text_config.sliding_window,
+            )
+            self.disable_sliding_window = True
 
         self.original_max_model_len = self.max_model_len
         self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
         self.multimodal_config = self._init_multimodal_config()
 
+        if self.disable_sliding_window:
+            # Set after get_and_verify_max_len to ensure that max_model_len
+            # can be correctly capped to sliding window size
+            self.hf_text_config.sliding_window = None
+
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()
@@ -1322,27 +1301,10 @@ class ModelConfig:
         if self.use_async_output_proc:
             self.use_async_output_proc = False
 
-    def get_hf_config_sliding_window(
-            self) -> Union[Optional[int], list[Optional[int]]]:
-        """Get the sliding window size, or None if disabled."""
-
-        # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
-        # addition to sliding window size. We check if that field is present
-        # and if it's False, return None.
-        if (hasattr(self.hf_text_config, "use_sliding_window")
-                and not self.hf_text_config.use_sliding_window):
-            return None
+    def get_sliding_window(self) -> Optional[int]:
+        """Get the sliding window size from the HF text config if present."""
         return getattr(self.hf_text_config, "sliding_window", None)
 
-    def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
-        """Get the sliding window size, or None if disabled.
-        """
-        # If user disables sliding window, return None.
-        if self.disable_sliding_window:
-            return None
-        # Otherwise get the value from the hf config.
-        return self.get_hf_config_sliding_window()
-
     def get_vocab_size(self) -> int:
         return getattr(self.hf_text_config, "vocab_size", 0)
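A small, self-contained sketch (not vLLM code) of why the single accessor now suffices: `ModelConfig.__post_init__` in the earlier hunk nulls `hf_text_config.sliding_window` when sliding window is disabled, so the getter no longer needs to consult `disable_sliding_window` or `use_sliding_window` itself. The class names below are hypothetical stand-ins.

```python
from typing import Optional


class HfTextConfigStub:
    def __init__(self, sliding_window: Optional[int] = 4096):
        self.sliding_window = sliding_window


class ModelConfigStub:
    def __init__(self, disable_sliding_window: bool = False):
        self.hf_text_config = HfTextConfigStub()
        if disable_sliding_window:
            # Mirrors what __post_init__ now does after get_and_verify_max_len.
            self.hf_text_config.sliding_window = None

    def get_sliding_window(self) -> Optional[int]:
        # Read the window straight from the (already normalised) HF config.
        return getattr(self.hf_text_config, "sliding_window", None)


assert ModelConfigStub().get_sliding_window() == 4096
assert ModelConfigStub(disable_sliding_window=True).get_sliding_window() is None
```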
@@ -1762,7 +1724,7 @@ class ModelConfig:
             tokenizer_config=tokenizer_config,
             max_model_len=max_model_len,
             disable_sliding_window=self.disable_sliding_window,
-            sliding_window_len=self.get_hf_config_sliding_window(),
+            sliding_window=self.get_sliding_window(),
             spec_target_max_model_len=self.spec_target_max_model_len,
             encoder_config=self.encoder_config)
         logger.info("Using max model len %s", max_model_len)
@@ -3305,7 +3267,7 @@ def _get_and_verify_max_len(
     tokenizer_config: Optional[dict],
     max_model_len: Optional[int],
     disable_sliding_window: bool,
-    sliding_window_len: Optional[Union[int, list[Optional[int]]]],
+    sliding_window: Optional[int],
     spec_target_max_model_len: Optional[int] = None,
    encoder_config: Optional[Any] = None,
 ) -> int:
@@ -3344,13 +3306,10 @@ def _get_and_verify_max_len(
 
     # If sliding window is manually disabled, max_length should be less
     # than the sliding window length in the model config.
-    if disable_sliding_window and sliding_window_len is not None:
-        sliding_window_len_min = get_min_sliding_window(sliding_window_len)
-        max_len_key = "sliding_window" \
-            if sliding_window_len_min < derived_max_model_len else max_len_key
-        derived_max_model_len = min(derived_max_model_len,
-                                    sliding_window_len_min)
+    if (disable_sliding_window and sliding_window is not None
+            and sliding_window < derived_max_model_len):
+        max_len_key = "sliding_window"
+        derived_max_model_len = sliding_window
 
     # Consider model_max_length in tokenizer_config
     if tokenizer_config:
@@ -3451,14 +3410,6 @@ def _get_and_verify_max_len(
     return int(max_model_len)
 
 
-def get_min_sliding_window(
-        sliding_window: Union[int, list[Optional[int]]]) -> int:
-    if isinstance(sliding_window, list):
-        return min(s for s in sliding_window if s is not None)
-
-    return sliding_window
-
-
 def get_served_model_name(model: str,
                           served_model_name: Optional[Union[str, list[str]]]):
     """
@@ -39,6 +39,7 @@ from vllm.plugins import load_general_plugins
 from vllm.ray.lazy_utils import is_ray_initialized
 from vllm.reasoning import ReasoningParserManager
 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
+from vllm.transformers_utils.config import is_interleaved
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
                         GiB_bytes, get_ip, is_in_ray_actor)
@@ -1081,6 +1082,13 @@ class EngineArgs:
                 "DualChunkFlashAttention is not supported on V1 engine. "
                 "To run the model in V0 engine, try set 'VLLM_USE_V1=0'")
 
+        sliding_window: Optional[int] = None
+        if not is_interleaved(model_config.hf_text_config):
+            # Only set CacheConfig.sliding_window if the model is all sliding
+            # window. Otherwise CacheConfig.sliding_window will override the
+            # global layers in interleaved sliding window models.
+            sliding_window = model_config.get_sliding_window()
+
         cache_config = CacheConfig(
             block_size=self.block_size,
             gpu_memory_utilization=self.gpu_memory_utilization,
@@ -1088,7 +1096,7 @@ class EngineArgs:
             cache_dtype=self.kv_cache_dtype,
             is_attention_free=model_config.is_attention_free,
             num_gpu_blocks_override=self.num_gpu_blocks_override,
-            sliding_window=model_config.get_sliding_window(),
+            sliding_window=sliding_window,
             enable_prefix_caching=self.enable_prefix_caching,
             prefix_caching_hash_algo=self.prefix_caching_hash_algo,
             cpu_offload_gb=self.cpu_offload_gb,
@@ -182,21 +182,13 @@ class CohereAttention(nn.Module):
         )
 
         # Model v2 has interleaved sliding windows, v1 does not
-        interleaved_sliding_window = getattr(config,
-                                             "interleaved_sliding_window",
-                                             None)
-        self.v1 = interleaved_sliding_window is None
-
-        layer_idx = extract_layer_index(prefix)
-        layer_has_sliding_window = (
-            getattr(config, "sliding_window_pattern", False) and
-            (layer_idx + 1) % self.config.sliding_window_pattern
-            != 0) or (getattr(config, "layer_types", False)
-                      and config.layer_types[layer_idx] == "sliding_attention")
-
-        self.sliding_window = (interleaved_sliding_window
-                               or config.sliding_window
-                               if layer_has_sliding_window else None)
+        self.v1 = isinstance(config, CohereConfig)
+
+        self.sliding_window = None
+        if not self.v1:
+            layer_idx = extract_layer_index(prefix)
+            if config.layer_types[layer_idx] == "sliding_attention":
+                self.sliding_window = config.sliding_window
 
         self.attn = Attention(self.num_heads,
                               self.head_dim,
@@ -159,25 +159,12 @@ class Exaone4Attention(nn.Module):
         if quant_config is not None and quant_config.get_name() == "gguf":
             is_neox_style = False
 
-        self.apply_all_layers = False  # apply rotary embeddings to every layer.
         layer_idx = extract_layer_index(prefix)
-        interleaved_sliding_window = getattr(config,
-                                             "interleaved_sliding_window",
-                                             4096)
-        sliding_window_pattern = getattr(config, "sliding_window_pattern",
-                                         "LLLG")
-
-        if sliding_window_pattern:
-            layer_has_sliding_window = (
-                layer_idx + 1) % sliding_window_pattern.__len__() != 0
-        else:
-            layer_has_sliding_window = False
-            self.apply_all_layers = True
-
-        if layer_has_sliding_window:
-            self.sliding_window = interleaved_sliding_window
-        else:
-            self.sliding_window = None
+        is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        self.sliding_window = config.sliding_window if is_sliding else None
+
+        # apply rotary embeddings to every layer
+        self.apply_all_layers = not is_sliding
 
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -144,13 +144,10 @@ class Gemma2Attention(nn.Module):
             is_neox_style=True,
         )
 
-        # reference:
-        # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa
         layer_idx = extract_layer_index(prefix)
-        use_sliding_window = (layer_idx % 2 == 0 and getattr(
-            config, "interleaved_sliding_window", None) is not None)
-        sliding_window = config.interleaved_sliding_window if \
-            use_sliding_window else None
+        is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        sliding_window = config.sliding_window if is_sliding else None
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
@@ -146,25 +146,19 @@ class Gemma3Attention(nn.Module):
         self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
         self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
 
-        # TODO(woosuk): Add reference to the original HF implementation.
         layer_idx = extract_layer_index(prefix)
-        self.is_sliding = (getattr(
-            config, "interleaved_sliding_window", None) is not None and (bool(
-                (layer_idx + 1) % config.sliding_window_pattern))) or (
-                    getattr(config, "layer_types", None) is not None
-                    and config.layer_types[layer_idx] == "sliding_attention")
+        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        sliding_window = config.sliding_window if self.is_sliding else None
         # Initialize the rotary embedding.
         if self.is_sliding:
             # Local attention. Override the values in config.json.
             self.rope_theta = config.rope_local_base_freq
             self.rope_scaling = {"rope_type": "default"}
-            self.sliding_window = (config.interleaved_sliding_window
-                                   or config.sliding_window)
         else:
             # Global attention. Use the values in config.json.
             self.rope_theta = config.rope_theta
             self.rope_scaling = config.rope_scaling
-            self.sliding_window = None
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
@@ -182,7 +176,7 @@ class Gemma3Attention(nn.Module):
                               cache_config=cache_config,
                               quant_config=quant_config,
                               logits_soft_cap=attn_logits_soft_cap,
-                              per_layer_sliding_window=self.sliding_window,
+                              per_layer_sliding_window=sliding_window,
                               prefix=f"{prefix}.attn")
 
     def forward(
@@ -502,8 +502,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         self.config = config
         self.quant_config = quant_config
         self.multimodal_config = multimodal_config
-        self.sliding_window = getattr(config.text_config,
-                                      "interleaved_sliding_window", None)
 
         self.vision_tower = SiglipVisionModel(config.vision_config,
                                               quant_config,
@@ -690,11 +688,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
             global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
             global_attn_masks.append(global_attn_mask)
 
-            if self.sliding_window is not None:
+            if (sliding_window := self.config.sliding_window) is not None:
                 # Create a local causal mask with sliding window (1024).
                 local_attn_mask = torch.ones_like(global_attn_mask)
                 local_attn_mask = torch.tril(local_attn_mask,
-                                             diagonal=-self.sliding_window)
+                                             diagonal=-sliding_window)
                 local_attn_mask = torch.where(local_attn_mask == 0,
                                               global_attn_mask, float("-inf"))
                 local_attn_masks.append(local_attn_mask)
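The masking logic in this hunk is easier to see in isolation. The standalone sketch below (sequence length and window size invented for the example) reproduces the same ones_like/tril/where arithmetic: positions at least `sliding_window` tokens in the past are forced to -inf, while everything inside the window falls back to the causal global mask.

```python
import torch

seq_len, sliding_window = 6, 2

# Causal "global" mask: 0 on/below the diagonal, -inf above it.
global_attn_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")),
                              diagonal=1)

# Same construction as the hunk: mark everything farther back than the window,
# then replace in-window entries with the global mask and the rest with -inf.
local_attn_mask = torch.ones_like(global_attn_mask)
local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
local_attn_mask = torch.where(local_attn_mask == 0, global_attn_mask,
                              float("-inf"))
print(local_attn_mask)
```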
@@ -313,17 +313,16 @@ class Gemma3nAttention(nn.Module):
                                       has_weight=False)
 
         layer_idx = extract_layer_index(prefix)
+        is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        self.sliding_window = config.sliding_window if is_sliding else None
 
-        is_sliding_window = (
-            getattr(config, "interleaved_sliding_window", None) is not None
-            and config.layer_types[layer_idx] == "sliding_attention")
-
-        if is_sliding_window:
-            self.sliding_window = config.interleaved_sliding_window
+        # Initialize the rotary embedding.
+        if is_sliding:
+            # Local attention. Override the values in config.json.
             rope_theta = config.rope_local_base_freq
             rope_scaling = {"rope_type": "default"}
         else:
-            self.sliding_window = None
+            # Global attention. Use the values in config.json.
             rope_theta = config.rope_theta
             rope_scaling = config.rope_scaling
 
@@ -248,9 +248,7 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
 
         vllm_config.cache_config.sliding_window = None
 
-        for attr in ("sliding_window", "interleaved_sliding_window"):
-            if hasattr(hf_config, attr):
-                delattr(hf_config, attr)
+        hf_config.sliding_window = None
 
         super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
 
@@ -167,18 +167,11 @@ class LlamaAttention(nn.Module):
             rope_scaling=rope_scaling,
             quant_config=quant_config)
 
-        if hasattr(config, "interleaved_sliding_window"):
-            interleaved_sliding_window = config.interleaved_sliding_window
-            if isinstance(interleaved_sliding_window, int):
-                sliding_window = interleaved_sliding_window
-            elif isinstance(interleaved_sliding_window, list):
-                sw_idx = layer_idx % len(interleaved_sliding_window)
-                sliding_window = interleaved_sliding_window[sw_idx]
-            else:
-                raise ValueError(
-                    f"{type(interleaved_sliding_window)} is not supported.")
-        else:
-            sliding_window = None
+        sliding_window = None
+        if layer_types := getattr(config, "layer_types", None):
+            is_sliding = layer_types[layer_idx] == "sliding_attention"
+            if is_sliding:
+                sliding_window = config.sliding_window
 
         self.attn = Attention(
             self.num_heads,
@@ -116,13 +116,8 @@ class SambaYAttention(nn.Module):
         self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True)
 
         # disable sliding window for the second half of the model
-        sliding_window = config.interleaved_sliding_window[layer_idx]
-        if layer_idx >= config.num_hidden_layers // 2:
-            assert sliding_window is None, \
-                "sliding_window must be none for the second decoder"
-        else:
-            assert sliding_window is not None, \
-                "sliding_window must be set for the first decoder"
+        is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        sliding_window = config.sliding_window if is_sliding else None
 
         assert self.num_heads % 2 == 0, 'num_heads should be even'
         assert self.num_key_value_heads % 2 == 0, 'num_heads should be even'
@@ -49,6 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.config import is_interleaved
 
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
@@ -285,8 +286,7 @@ class Qwen2Model(nn.Module):
         quant_config = vllm_config.quant_config
 
         # TODO (@robertgshaw2): see if this can be moved out
-        if (cache_config.sliding_window is not None
-                and hasattr(config, "max_window_layers")):
+        if is_interleaved(vllm_config.model_config.hf_text_config):
             assert config.max_window_layers == config.num_hidden_layers, (
                 "Sliding window for some but all layers is not supported. "
                 "This model uses sliding window but `max_window_layers` = {} "
@@ -16,7 +16,7 @@
 # limitations under the License.
 """Wrapper around `transformers` models"""
 from collections.abc import Iterable, Mapping
-from contextlib import contextmanager, nullcontext
+from contextlib import contextmanager
 from typing import Literal, Optional, Union
 
 import regex as re
@@ -382,33 +382,6 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
         )
 
 
-class ConfigOverride:
-    """Context manager to temporarily override config attributes."""
-
-    def __init__(self, config: PretrainedConfig, **kwargs):
-        self.config = config
-        self.kwargs = kwargs
-        self.kwargs_original = {}
-        self.kwargs_delete = set()
-
-    def __enter__(self):
-        """Override config attributes."""
-        for key, value in self.kwargs.items():
-            if not hasattr(self.config, key):
-                self.kwargs_delete.add(key)
-            self.kwargs_original[key] = getattr(self.config, key, None)
-            setattr(self.config, key, value)
-        return self.config
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        """Restore original config attributes."""
-        for key, value in self.kwargs_original.items():
-            if key in self.kwargs_delete:
-                delattr(self.config, key)
-            else:
-                setattr(self.config, key, value)
-
-
 class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
     embedding_padding_modules = ["lm_head"]
     embedding_modules = ["embed_tokens"
@@ -434,21 +407,11 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         # To be updated in child classes for use in `load_weights`
         self.skip_prefixes: Optional[list[str]] = None
 
-        # vLLM handles interleaved sliding window attention by creating a new
-        # interleaved_sliding_window attribute and deleting the sliding_window
-        # attribute. This breaks the constructors in Transformers so we
-        # temporarily add the attribute back to construct the model.
-        config_override = nullcontext()
-        if hasattr(self.config, "interleaved_sliding_window"):
-            config_override = ConfigOverride(
-                self.config,
-                sliding_window=self.config.interleaved_sliding_window)
-
         # Set correct attn and init on "meta" to delay allocating GPU tensors
         # TODO: @raushan, use the public `model.set_attn_implementation()`
         # method once its checks are fixed in Transformers.
         self.text_config._attn_implementation = "vllm"
-        with init_on_device_without_buffers("meta"), config_override:
+        with init_on_device_without_buffers("meta"):
             self.model: PreTrainedModel = AutoModel.from_config(
                 self.config,
                 torch_dtype=self.model_config.dtype,
@@ -575,11 +538,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         attention_instances = {}
         for i in range(start, end):
             # Handle interleaved sliding window attention
-            sliding_window = None
-            if (hasattr(self.config, "interleaved_sliding_window")
-                    and hasattr(self.config, "sliding_window_pattern")
-                    and ((i + 1) % self.config.sliding_window_pattern > 0)):
-                sliding_window = self.config.interleaved_sliding_window
+            per_layer_sliding_window = None
+            if (hasattr(self.config, "layer_types")
+                    and self.config.layer_types[i] == "sliding_attention"):
+                per_layer_sliding_window = self.config.sliding_window
 
             attention_instances[i] = Attention(
                 num_heads=num_heads,
@@ -590,7 +552,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
                 num_kv_heads=num_kv_heads,
                 cache_config=self.cache_config,
                 quant_config=self.quant_config,
-                per_layer_sliding_window=sliding_window,
+                per_layer_sliding_window=per_layer_sliding_window,
                 prefix=f"{i}.attn")
         return attention_instances
 
@@ -280,6 +280,17 @@ def is_encoder_decoder(config: PretrainedConfig) -> bool:
     return getattr(config, "is_encoder_decoder", False)
 
 
+def is_interleaved(config: PretrainedConfig) -> bool:
+    """
+    Detect if the model with this config is used with interleaved attention.
+    """
+    text_config = config.get_text_config()
+    if layer_types := getattr(text_config, "layer_types", None):
+        interleaved_types = {"full_attention", "sliding_attention"}
+        return interleaved_types.issubset(layer_types)
+    return False
+
+
 def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
     """Remap config attributes to match the expected names."""
     for old_attr, new_attr in _CONFIG_ATTRS_MAPPING.items():
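The new `is_interleaved` helper is the single place where "does this model interleave attention types?" is now decided. A quick usage sketch with invented configs (the helper body is copied here only to keep the example self-contained):

```python
from transformers import PretrainedConfig

full = PretrainedConfig(layer_types=["full_attention"] * 4)
mixed = PretrainedConfig(
    layer_types=["sliding_attention", "sliding_attention", "full_attention"])
bare = PretrainedConfig()  # no layer_types attribute at all


def is_interleaved(config: PretrainedConfig) -> bool:
    # Local copy of the helper added above.
    text_config = config.get_text_config()
    if layer_types := getattr(text_config, "layer_types", None):
        interleaved_types = {"full_attention", "sliding_attention"}
        return interleaved_types.issubset(layer_types)
    return False


assert not is_interleaved(full)   # all-full: not interleaved
assert is_interleaved(mixed)      # mixes both types: interleaved
assert not is_interleaved(bare)   # no layer_types: treated as not interleaved
```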
@@ -423,6 +434,23 @@ def get_config(
             raise e
         config = _maybe_remap_hf_config_attrs(config)
 
+        # Phi4Flash misuses this config as list[int]. Convert it to int and add
+        # the layer_types list[str] to make it HF compatible
+        if (config.model_type == "phi4flash"):
+            # TODO: Remove after the following PR is merged:
+            # https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/discussions/6
+            if not hasattr(config, "layer_types"):
+                config.layer_types = [
+                    "sliding_attention" if i < config.num_hidden_layers // 2
+                    and i % 2 == 1 else "full_attention"
+                    for i in range(config.num_hidden_layers)
+                ]
+            # TODO: Remove after the following PR is merged:
+            # https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/discussions/7
+            if isinstance(config.sliding_window, list):
+                config.sliding_window = next(
+                    filter(None, config.sliding_window), None)
+
     elif config_format == ConfigFormat.MISTRAL:
         # This function loads a params.json config which
         # should be used when loading models in mistral format
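To make the backfill above concrete, here is the layer map the comprehension produces for a hypothetical 8-layer config (the layer count is invented for illustration, not taken from the real Phi-4-mini-flash config):

```python
# Odd-indexed layers in the first half become sliding attention,
# everything else stays full attention.
num_hidden_layers = 8
layer_types = [
    "sliding_attention" if i < num_hidden_layers // 2 and i % 2 == 1
    else "full_attention"
    for i in range(num_hidden_layers)
]
print(layer_types)
# ['full_attention', 'sliding_attention', 'full_attention', 'sliding_attention',
#  'full_attention', 'full_attention', 'full_attention', 'full_attention']
```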
@@ -434,6 +462,18 @@ def get_config(
             config_dict["max_position_embeddings"] = max_position_embeddings
 
         config = adapt_config_dict(config_dict)
+
+        # Mistral configs may define sliding_window as list[int]. Convert it
+        # to int and add the layer_types list[str] to make it HF compatible
+        if ((sliding_window := getattr(config, "sliding_window", None))
+                and isinstance(sliding_window, list)):
+            pattern_repeats = config.num_hidden_layers // len(sliding_window)
+            layer_types = sliding_window * pattern_repeats
+            config.layer_types = [
+                "full_attention" if layer_type is None else "sliding_attention"
+                for layer_type in layer_types
+            ]
+            config.sliding_window = next(filter(None, sliding_window), None)
     else:
         supported_formats = [
             fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO
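A worked example of this conversion, with invented values: a Mistral-format config that declares `sliding_window = [4096, None]` over 4 hidden layers ends up with alternating layer types and a single integer window.

```python
# Invented values illustrating the Mistral-format conversion above.
num_hidden_layers = 4
sliding_window = [4096, None]          # as read from params.json

pattern_repeats = num_hidden_layers // len(sliding_window)
layer_types = sliding_window * pattern_repeats
layer_types = [
    "full_attention" if layer_type is None else "sliding_attention"
    for layer_type in layer_types
]
sliding_window = next(filter(None, sliding_window), None)

print(layer_types)
# ['sliding_attention', 'full_attention', 'sliding_attention', 'full_attention']
print(sliding_window)  # 4096
```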