commit 4aba6e3d1a
parent 978b39744b

[core] gemma2 full context length support (#10584)

Signed-off-by: youkaichao <youkaichao@gmail.com>
```diff
@@ -14,11 +14,12 @@ from vllm import LLM
 from vllm.platforms import current_platform
 from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
 
+from ..conftest import VllmRunner
 from ..models.utils import check_outputs_equal
 from ..utils import multi_gpu_test
 
 MODELS = [
-    "facebook/opt-125m",
+    "google/gemma-2-2b-it",
     "meta-llama/Llama-3.2-1B",
 ]
 
@@ -42,8 +43,6 @@ def test_vllm_gc_ed():
 @pytest.mark.parametrize("enforce_eager", [False, True])
 def test_models(
     hf_runner,
-    vllm_runner,
-    example_prompts,
     model: str,
     backend: str,
     dtype: str,
@@ -54,15 +53,27 @@ def test_models(
     if backend == "FLASHINFER" and current_platform.is_rocm():
         pytest.skip("Flashinfer does not support ROCm/HIP.")
 
+    if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
+        pytest.skip(
+            "XFORMERS does not support gemma2 with full context length.")
+
     os.environ["VLLM_ATTENTION_BACKEND"] = backend
 
+    # 5042 tokens for gemma2
+    # gemma2 has alternating sliding window size of 4096
+    # we need a prompt with more than 4096 tokens to test the sliding window
+    prompt = "The following numbers of the sequence " + ", ".join(
+        str(i) for i in range(1024)) + " are:"
+    example_prompts = [prompt]
+
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
 
-    with vllm_runner(model,
-                     dtype=dtype,
-                     enforce_eager=enforce_eager,
-                     gpu_memory_utilization=0.7) as vllm_model:
+    with VllmRunner(model,
+                    max_model_len=8192,
+                    dtype=dtype,
+                    enforce_eager=enforce_eager,
+                    gpu_memory_utilization=0.7) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
     check_outputs_equal(
```
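For orientation, the basic-correctness test above builds a prompt of roughly 5000 tokens so that gemma2's 4096-token sliding window is actually exceeded. Below is a minimal sketch of the same scenario through the public `vllm` API; the model name, `max_model_len=8192`, prompt construction, and `gpu_memory_utilization` mirror the test, while the sampling parameters are illustrative.

```python
from vllm import LLM, SamplingParams

# Prompt that tokenizes to well over 4096 tokens, as in the test above.
prompt = ("The following numbers of the sequence " +
          ", ".join(str(i) for i in range(1024)) + " are:")

# max_model_len=8192 exceeds gemma2's 4096-token sliding window, which this
# commit makes usable end to end.
llm = LLM(model="google/gemma-2-2b-it",
          max_model_len=8192,
          gpu_memory_utilization=0.7)

# Greedy decoding, matching generate_greedy in the test (token budget illustrative).
outputs = llm.generate([prompt], SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)
```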
```diff
@@ -40,18 +40,26 @@ class Attention(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
+        per_layer_sliding_window: Optional[int] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
+        if per_layer_sliding_window is not None:
+            # per-layer sliding window
+            sliding_window = per_layer_sliding_window
+        elif cache_config is not None:
+            # model-level sliding window
+            sliding_window = cache_config.sliding_window
+        else:
+            sliding_window = None
+
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
             block_size = cache_config.block_size
-            sliding_window = cache_config.sliding_window
             is_attention_free = cache_config.is_attention_free
         else:
             kv_cache_dtype = "auto"
             block_size = 16
-            sliding_window = None
             is_attention_free = False
         if num_kv_heads is None:
             num_kv_heads = num_heads
```
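The `Attention` constructor (vLLM's attention layer wrapper) now resolves the window in a fixed order: an explicit `per_layer_sliding_window` wins, then the model-level `cache_config.sliding_window`, then none. A standalone sketch of that precedence, not vLLM code:

```python
from typing import Optional


def resolve_sliding_window(per_layer_sliding_window: Optional[int],
                           model_level_sliding_window: Optional[int]
                           ) -> Optional[int]:
    """Mirror of the precedence in the hunk above: per-layer value first,
    then the model-level value, else no sliding window."""
    if per_layer_sliding_window is not None:
        return per_layer_sliding_window
    return model_level_sliding_window


assert resolve_sliding_window(4096, None) == 4096   # sliding-window layer
assert resolve_sliding_window(None, 2048) == 2048   # model-wide window
assert resolve_sliding_window(None, None) is None   # full attention
```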
```diff
@@ -233,15 +233,26 @@ class ModelConfig:
             (self.hf_text_config.model_type in ["gemma2"]))
 
         if (not self.disable_sliding_window and has_interleaved_attention):
-            sliding_window_len_min = get_min_sliding_window(
-                self.hf_text_config.sliding_window)
-
-            print_warning_once(
-                f"{self.hf_text_config.model_type} has interleaved attention, "
-                "which is currently not supported by vLLM. Disabling sliding "
-                "window and capping the max length to the sliding window size "
-                f"({sliding_window_len_min}).")
-            self.disable_sliding_window = True
+            if envs.VLLM_ATTENTION_BACKEND == "XFORMERS":
+                sliding_window_len_min = get_min_sliding_window(
+                    self.hf_text_config.sliding_window)
+
+                print_warning_once(
+                    f"{self.hf_text_config.model_type} has interleaved "
+                    "attention, which is currently not supported by the "
+                    "XFORMERS backend. Disabling sliding window and capping "
+                    "the max length to the sliding window size "
+                    f"({sliding_window_len_min}).")
+                self.disable_sliding_window = True
+            else:
+                # for a model with interleaved attention,
+                # the scheduler and the model treat it as full attention
+                # (i.e., not dropping any tokens outside the window).
+                # only the attention layer itself is aware of the sliding
+                # window, and use the window size to compute the attention.
+                self.hf_text_config.interleaved_sliding_window = sliding_window
+                delattr(self.hf_text_config, "sliding_window")
+                sliding_window = None
 
         self.max_model_len = _get_and_verify_max_len(
             hf_config=self.hf_text_config,
```
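The effect of the `else` branch above: the window is moved from `hf_text_config.sliding_window` to `hf_text_config.interleaved_sliding_window`, so length capping and KV-cache scheduling see a full-attention model, while each attention layer can still apply its own window. A small arithmetic-only illustration (not vLLM code) of what that split means for which cached positions a layer reads:

```python
from typing import Optional


def readable_positions(query_pos: int, window: Optional[int]) -> range:
    """Key positions a query at `query_pos` may attend to under causal
    attention, optionally restricted to a sliding window."""
    if window is None:          # full-attention layer
        return range(0, query_pos + 1)
    return range(max(0, query_pos - window + 1), query_pos + 1)


# The scheduler keeps the KV cache for all 5000 positions (full-attention
# view), but a sliding-window layer still reads only the last 4096 of them.
assert readable_positions(4999, None).start == 0
assert readable_positions(4999, 4096).start == 4999 - 4096 + 1
```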
```diff
@@ -143,12 +143,12 @@ class Gemma2Attention(nn.Module):
             is_neox_style=True,
         )
 
-        # FIXME(woosuk): While Gemma 2 uses sliding window attention for every
-        # odd layer, vLLM currently ignores it and uses global attention for
-        # all layers.
-        use_sliding_window = (layer_idx % 2 == 1
-                              and config.sliding_window is not None)
-        del use_sliding_window  # Unused.
+        # reference:
+        # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa
+        use_sliding_window = (layer_idx % 2 == 0 and
+                              config.interleaved_sliding_window is not None)
+        sliding_window = config.interleaved_sliding_window if \
+            use_sliding_window else None
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
@@ -156,6 +156,7 @@ class Gemma2Attention(nn.Module):
                               cache_config=cache_config,
                               quant_config=quant_config,
                               logits_soft_cap=attn_logits_soft_cap,
+                              per_layer_sliding_window=sliding_window,
                               prefix=f"{prefix}.attn")
 
     def forward(
```
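Putting the Gemma2Attention hunks together: following the Hugging Face reference linked in the diff, even-indexed layers use the interleaved sliding window and odd-indexed layers use full attention, and the resulting per-layer value is what gets passed as `per_layer_sliding_window`. A sketch of that assignment with illustrative values (`num_layers=6`, window size 4096):

```python
from typing import List, Optional


def per_layer_windows(num_layers: int,
                      interleaved_sliding_window: Optional[int]
                      ) -> List[Optional[int]]:
    """Window assigned to each layer: even layer_idx -> sliding window,
    odd layer_idx -> full attention, as in the Gemma2Attention hunk above."""
    return [
        interleaved_sliding_window
        if (layer_idx % 2 == 0 and interleaved_sliding_window is not None)
        else None
        for layer_idx in range(num_layers)
    ]


print(per_layer_windows(6, 4096))  # [4096, None, 4096, None, 4096, None]
```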