mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:45:01 +08:00
Support mistral interleaved attn (#9414)
This commit is contained in:
parent
cf1d62a644
commit
415f76a9cb
@ -173,14 +173,20 @@ class ModelConfig:
|
||||
if self.enforce_eager is None:
|
||||
self.enforce_eager = False
|
||||
|
||||
if (not self.disable_sliding_window
|
||||
and self.hf_text_config.model_type == "gemma2"
|
||||
and self.hf_text_config.sliding_window is not None):
|
||||
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
|
||||
has_interleaved_attention = (sliding_window is not None) and (
|
||||
isinstance(sliding_window, list) or
|
||||
(self.hf_text_config.model_type in ["gemma2"]))
|
||||
|
||||
if (not self.disable_sliding_window and has_interleaved_attention):
|
||||
sliding_window_len_min = get_min_sliding_window(
|
||||
self.hf_text_config.sliding_window)
|
||||
|
||||
print_warning_once(
|
||||
"Gemma 2 uses sliding window attention for every odd layer, "
|
||||
f"{self.hf_text_config.model_type} has interleaved attention, "
|
||||
"which is currently not supported by vLLM. Disabling sliding "
|
||||
"window and capping the max length to the sliding window size "
|
||||
f"({self.hf_text_config.sliding_window}).")
|
||||
f"({sliding_window_len_min}).")
|
||||
self.disable_sliding_window = True
|
||||
|
||||
self.max_model_len = _get_and_verify_max_len(
|
||||
@ -431,7 +437,8 @@ class ModelConfig:
|
||||
"pipeline parallelism currently. Disabling it.")
|
||||
self.use_async_output_proc = False
|
||||
|
||||
def get_hf_config_sliding_window(self) -> Optional[int]:
|
||||
def get_hf_config_sliding_window(
|
||||
self) -> Union[Optional[int], List[Optional[int]]]:
|
||||
"""Get the sliding window size, or None if disabled."""
|
||||
|
||||
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
|
||||
@ -442,7 +449,7 @@ class ModelConfig:
|
||||
return None
|
||||
return getattr(self.hf_text_config, "sliding_window", None)
|
||||
|
||||
def get_sliding_window(self) -> Optional[int]:
|
||||
def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]:
|
||||
"""Get the sliding window size, or None if disabled.
|
||||
"""
|
||||
# If user disables sliding window, return None.
|
||||
@ -1689,7 +1696,7 @@ def _get_and_verify_max_len(
|
||||
hf_config: PretrainedConfig,
|
||||
max_model_len: Optional[int],
|
||||
disable_sliding_window: bool,
|
||||
sliding_window_len: Optional[int],
|
||||
sliding_window_len: Optional[Union[int, List[Optional[int]]]],
|
||||
spec_target_max_model_len: Optional[int] = None,
|
||||
) -> int:
|
||||
"""Get and verify the model's maximum length."""
|
||||
@ -1722,9 +1729,12 @@ def _get_and_verify_max_len(
|
||||
# If sliding window is manually disabled, max_length should be less
|
||||
# than the sliding window length in the model config.
|
||||
if disable_sliding_window and sliding_window_len is not None:
|
||||
|
||||
sliding_window_len_min = get_min_sliding_window(sliding_window_len)
|
||||
max_len_key = "sliding_window" \
|
||||
if sliding_window_len < derived_max_model_len else max_len_key
|
||||
derived_max_model_len = min(derived_max_model_len, sliding_window_len)
|
||||
if sliding_window_len_min < derived_max_model_len else max_len_key
|
||||
derived_max_model_len = min(derived_max_model_len,
|
||||
sliding_window_len_min)
|
||||
|
||||
# If none of the keys were found in the config, use a default and
|
||||
# log a warning.
|
||||
@ -1805,6 +1815,14 @@ def _get_and_verify_max_len(
|
||||
return int(max_model_len)
|
||||
|
||||
|
||||
def get_min_sliding_window(
        sliding_window: Union[int, List[Optional[int]]]) -> int:
    """Return the smallest sliding window size described by `sliding_window`.

    Args:
        sliding_window: Either a single window size, or a list of window
            sizes in which `None` entries are ignored.

    Returns:
        `sliding_window` itself when it is not a list, otherwise the
        minimum of the non-`None` entries of the list.

    Raises:
        ValueError: If `sliding_window` is a list with no non-`None`
            entries, so that no minimum exists.
    """
    if not isinstance(sliding_window, list):
        return sliding_window

    # `default=None` keeps min() from raising its generic "empty sequence"
    # error, so the failure can be reported with the offending value below.
    smallest = min((size for size in sliding_window if size is not None),
                   default=None)
    if smallest is None:
        raise ValueError(
            "Cannot determine the minimum sliding window: sliding_window "
            f"list {sliding_window!r} contains no non-None entries.")
    return smallest
|
||||
|
||||
|
||||
def get_served_model_name(model: str,
|
||||
served_model_name: Optional[Union[str, List[str]]]):
|
||||
"""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user