[core] gemma2 full context length support (#10584)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 978b39744b
commit 4aba6e3d1a
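For orientation: this commit lets Gemma 2 models run with their full 8192-token context instead of being capped at the 4096-token sliding window. A minimal usage sketch (the sampling settings and prompt below are illustrative, not taken from this commit):

# Hypothetical example: serve Gemma 2 at its full context length with vLLM.
from vllm import LLM, SamplingParams

llm = LLM(model="google/gemma-2-2b-it", max_model_len=8192)
params = SamplingParams(temperature=0.0, max_tokens=16)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
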
@@ -14,11 +14,12 @@ from vllm import LLM
 from vllm.platforms import current_platform
 from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
 
+from ..conftest import VllmRunner
 from ..models.utils import check_outputs_equal
 from ..utils import multi_gpu_test
 
 MODELS = [
     "facebook/opt-125m",
     "google/gemma-2-2b-it",
     "meta-llama/Llama-3.2-1B",
 ]

@@ -42,8 +43,6 @@ def test_vllm_gc_ed():
 @pytest.mark.parametrize("enforce_eager", [False, True])
 def test_models(
     hf_runner,
-    vllm_runner,
-    example_prompts,
     model: str,
     backend: str,
     dtype: str,

@@ -54,15 +53,27 @@ def test_models(
     if backend == "FLASHINFER" and current_platform.is_rocm():
         pytest.skip("Flashinfer does not support ROCm/HIP.")
 
+    if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
+        pytest.skip(
+            "XFORMERS does not support gemma2 with full context length.")
+
     os.environ["VLLM_ATTENTION_BACKEND"] = backend
 
+    # 5042 tokens for gemma2
+    # gemma2 has alternating sliding window size of 4096
+    # we need a prompt with more than 4096 tokens to test the sliding window
+    prompt = "The following numbers of the sequence " + ", ".join(
+        str(i) for i in range(1024)) + " are:"
+    example_prompts = [prompt]
+
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
 
-    with vllm_runner(model,
-                     dtype=dtype,
-                     enforce_eager=enforce_eager,
-                     gpu_memory_utilization=0.7) as vllm_model:
+    with VllmRunner(model,
+                    max_model_len=8192,
+                    dtype=dtype,
+                    enforce_eager=enforce_eager,
+                    gpu_memory_utilization=0.7) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
     check_outputs_equal(

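The test above relies on the constructed prompt exceeding Gemma 2's 4096-token sliding window. A rough, optional sanity check of that assumption (not part of the commit; requires `transformers` and access to the gemma-2 tokenizer):

# Count tokens in the same prompt the test builds; the exact number (~5042)
# depends on the tokenizer, but it should comfortably exceed 4096.
from transformers import AutoTokenizer

prompt = "The following numbers of the sequence " + ", ".join(
    str(i) for i in range(1024)) + " are:"
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
num_tokens = len(tokenizer(prompt).input_ids)
print(num_tokens)
assert num_tokens > 4096  # long enough to exercise the sliding window
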
@@ -40,18 +40,26 @@ class Attention(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
+        per_layer_sliding_window: Optional[int] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
+        if per_layer_sliding_window is not None:
+            # per-layer sliding window
+            sliding_window = per_layer_sliding_window
+        elif cache_config is not None:
+            # model-level sliding window
+            sliding_window = cache_config.sliding_window
+        else:
+            sliding_window = None
+
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
             block_size = cache_config.block_size
-            sliding_window = cache_config.sliding_window
             is_attention_free = cache_config.is_attention_free
         else:
             kv_cache_dtype = "auto"
             block_size = 16
-            sliding_window = None
             is_attention_free = False
         if num_kv_heads is None:
             num_kv_heads = num_heads

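The new `per_layer_sliding_window` argument takes precedence over the model-level `cache_config.sliding_window`. A standalone sketch of just that precedence rule (illustrative helper, not vLLM code):

from typing import Optional

def resolve_sliding_window(
        per_layer_sliding_window: Optional[int],
        model_level_sliding_window: Optional[int]) -> Optional[int]:
    # The per-layer value wins (e.g. Gemma 2's alternating layers); otherwise
    # fall back to the model-level setting, which may itself be None.
    if per_layer_sliding_window is not None:
        return per_layer_sliding_window
    return model_level_sliding_window

assert resolve_sliding_window(4096, None) == 4096   # sliding-window layer
assert resolve_sliding_window(None, None) is None   # full-attention layer
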
@@ -233,15 +233,26 @@ class ModelConfig:
             (self.hf_text_config.model_type in ["gemma2"]))
 
         if (not self.disable_sliding_window and has_interleaved_attention):
-            sliding_window_len_min = get_min_sliding_window(
-                self.hf_text_config.sliding_window)
+            if envs.VLLM_ATTENTION_BACKEND == "XFORMERS":
+                sliding_window_len_min = get_min_sliding_window(
+                    self.hf_text_config.sliding_window)
 
-            print_warning_once(
-                f"{self.hf_text_config.model_type} has interleaved attention, "
-                "which is currently not supported by vLLM. Disabling sliding "
-                "window and capping the max length to the sliding window size "
-                f"({sliding_window_len_min}).")
-            self.disable_sliding_window = True
+                print_warning_once(
+                    f"{self.hf_text_config.model_type} has interleaved "
+                    "attention, which is currently not supported by the "
+                    "XFORMERS backend. Disabling sliding window and capping "
+                    "the max length to the sliding window size "
+                    f"({sliding_window_len_min}).")
+                self.disable_sliding_window = True
+            else:
+                # for a model with interleaved attention,
+                # the scheduler and the model treat it as full attention
+                # (i.e., not dropping any tokens outside the window).
+                # only the attention layer itself is aware of the sliding
+                # window, and use the window size to compute the attention.
+                self.hf_text_config.interleaved_sliding_window = sliding_window
+                delattr(self.hf_text_config, "sliding_window")
+                sliding_window = None
 
         self.max_model_len = _get_and_verify_max_len(
             hf_config=self.hf_text_config,

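In the non-XFORMERS branch, the window size moves from `sliding_window` to `interleaved_sliding_window`, so the scheduler keeps KV cache for the full context while only the attention layers consult the window. An illustrative sketch of that transformation on a dummy config object (not vLLM code):

from types import SimpleNamespace

hf_text_config = SimpleNamespace(model_type="gemma2", sliding_window=4096)

sliding_window = getattr(hf_text_config, "sliding_window", None)
hf_text_config.interleaved_sliding_window = sliding_window
delattr(hf_text_config, "sliding_window")
sliding_window = None

# The scheduler no longer sees a sliding window (full attention), while the
# attention layers can still read the window size.
assert getattr(hf_text_config, "sliding_window", None) is None
assert hf_text_config.interleaved_sliding_window == 4096
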
@@ -143,12 +143,12 @@ class Gemma2Attention(nn.Module):
             is_neox_style=True,
         )
 
-        # FIXME(woosuk): While Gemma 2 uses sliding window attention for every
-        # odd layer, vLLM currently ignores it and uses global attention for
-        # all layers.
-        use_sliding_window = (layer_idx % 2 == 1
-                              and config.sliding_window is not None)
-        del use_sliding_window  # Unused.
+        # reference:
+        # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312  # noqa
+        use_sliding_window = (layer_idx % 2 == 0 and
+                              config.interleaved_sliding_window is not None)
+        sliding_window = config.interleaved_sliding_window if \
+            use_sliding_window else None
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,

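With this rule, even-indexed layers use the interleaved sliding window and odd-indexed layers attend globally, matching the referenced HF implementation. A quick illustration (the dummy config below is an assumption, not the real Gemma2Config):

from types import SimpleNamespace

config = SimpleNamespace(interleaved_sliding_window=4096, num_hidden_layers=6)

per_layer_windows = []
for layer_idx in range(config.num_hidden_layers):
    use_sliding_window = (layer_idx % 2 == 0
                          and config.interleaved_sliding_window is not None)
    per_layer_windows.append(
        config.interleaved_sliding_window if use_sliding_window else None)

# Even layers attend within the 4096 window, odd layers attend globally.
print(per_layer_windows)  # [4096, None, 4096, None, 4096, None]
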
@@ -156,6 +156,7 @@ class Gemma2Attention(nn.Module):
                               cache_config=cache_config,
                               quant_config=quant_config,
                               logits_soft_cap=attn_logits_soft_cap,
+                              per_layer_sliding_window=sliding_window,
                               prefix=f"{prefix}.attn")
 
     def forward(