[core] gemma2 full context length support (#10584)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Authored by youkaichao on 2024-11-22 20:13:54 -08:00, committed by GitHub
parent 978b39744b
commit 4aba6e3d1a
4 changed files with 54 additions and 23 deletions
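
For context, a minimal usage sketch (not part of the diff): with this change a Gemma 2 model can run at its full 8192-token context instead of being capped to the 4096-token sliding window, provided the attention backend is not XFORMERS. The prompt and sampling settings below are illustrative.

# usage sketch, assumes a backend with per-layer sliding window support (e.g. FLASH_ATTN)
from vllm import LLM, SamplingParams

llm = LLM(model="google/gemma-2-2b-it", max_model_len=8192)  # previously capped at 4096
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)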

View File

@@ -14,11 +14,12 @@ from vllm import LLM
 from vllm.platforms import current_platform
 from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

+from ..conftest import VllmRunner
 from ..models.utils import check_outputs_equal
 from ..utils import multi_gpu_test

 MODELS = [
-    "facebook/opt-125m",
+    "google/gemma-2-2b-it",
     "meta-llama/Llama-3.2-1B",
 ]
@@ -42,8 +43,6 @@ def test_vllm_gc_ed():
 @pytest.mark.parametrize("enforce_eager", [False, True])
 def test_models(
     hf_runner,
-    vllm_runner,
-    example_prompts,
     model: str,
     backend: str,
     dtype: str,
@@ -54,15 +53,27 @@ def test_models(
     if backend == "FLASHINFER" and current_platform.is_rocm():
         pytest.skip("Flashinfer does not support ROCm/HIP.")

+    if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
+        pytest.skip(
+            "XFORMERS does not support gemma2 with full context length.")
+
     os.environ["VLLM_ATTENTION_BACKEND"] = backend

+    # 5042 tokens for gemma2
+    # gemma2 has alternating sliding window size of 4096
+    # we need a prompt with more than 4096 tokens to test the sliding window
+    prompt = "The following numbers of the sequence " + ", ".join(
+        str(i) for i in range(1024)) + " are:"
+    example_prompts = [prompt]
+
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

-    with vllm_runner(model,
-                     dtype=dtype,
-                     enforce_eager=enforce_eager,
-                     gpu_memory_utilization=0.7) as vllm_model:
+    with VllmRunner(model,
+                    max_model_len=8192,
+                    dtype=dtype,
+                    enforce_eager=enforce_eager,
+                    gpu_memory_utilization=0.7) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

     check_outputs_equal(
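
The new prompt is built to exceed Gemma 2's 4096-token sliding window (the comment in the diff cites roughly 5042 tokens). A quick sanity check of that length, sketched here rather than taken from the diff, can use the Hugging Face tokenizer (assumes transformers and access to the checkpoint):

# sketch: count tokens of the test prompt
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
prompt = "The following numbers of the sequence " + ", ".join(
    str(i) for i in range(1024)) + " are:"
print(len(tok(prompt).input_ids))  # roughly 5042 tokens, well above the 4096 window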

View File

@@ -40,18 +40,26 @@ class Attention(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
+        per_layer_sliding_window: Optional[int] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
+        if per_layer_sliding_window is not None:
+            # per-layer sliding window
+            sliding_window = per_layer_sliding_window
+        elif cache_config is not None:
+            # model-level sliding window
+            sliding_window = cache_config.sliding_window
+        else:
+            sliding_window = None
+
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
             block_size = cache_config.block_size
-            sliding_window = cache_config.sliding_window
             is_attention_free = cache_config.is_attention_free
         else:
             kv_cache_dtype = "auto"
             block_size = 16
-            sliding_window = None
             is_attention_free = False
         if num_kv_heads is None:
             num_kv_heads = num_heads
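
The new per_layer_sliding_window argument takes precedence over the model-level cache_config.sliding_window, with None meaning "fall back to the model-level value". A standalone sketch of that precedence (hypothetical helper, not vLLM code):

from typing import Optional

def resolve_sliding_window(
        per_layer_sliding_window: Optional[int],
        model_level_sliding_window: Optional[int]) -> Optional[int]:
    # hypothetical helper mirroring the precedence added to Attention.__init__
    if per_layer_sliding_window is not None:
        # the layer supplied its own window size
        return per_layer_sliding_window
    # otherwise fall back to the model-level setting (may itself be None)
    return model_level_sliding_window

assert resolve_sliding_window(4096, None) == 4096  # gemma2 sliding-window layer
assert resolve_sliding_window(None, 4096) == 4096  # model-level default
assert resolve_sliding_window(None, None) is None  # global attention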

View File

@@ -233,15 +233,26 @@ class ModelConfig:
             (self.hf_text_config.model_type in ["gemma2"]))

         if (not self.disable_sliding_window and has_interleaved_attention):
-            sliding_window_len_min = get_min_sliding_window(
-                self.hf_text_config.sliding_window)
-
-            print_warning_once(
-                f"{self.hf_text_config.model_type} has interleaved attention, "
-                "which is currently not supported by vLLM. Disabling sliding "
-                "window and capping the max length to the sliding window size "
-                f"({sliding_window_len_min}).")
-            self.disable_sliding_window = True
+            if envs.VLLM_ATTENTION_BACKEND == "XFORMERS":
+                sliding_window_len_min = get_min_sliding_window(
+                    self.hf_text_config.sliding_window)
+
+                print_warning_once(
+                    f"{self.hf_text_config.model_type} has interleaved "
+                    "attention, which is currently not supported by the "
+                    "XFORMERS backend. Disabling sliding window and capping "
+                    "the max length to the sliding window size "
+                    f"({sliding_window_len_min}).")
+                self.disable_sliding_window = True
+            else:
+                # for a model with interleaved attention,
+                # the scheduler and the model treat it as full attention
+                # (i.e., not dropping any tokens outside the window).
+                # only the attention layer itself is aware of the sliding
+                # window, and use the window size to compute the attention.
+                self.hf_text_config.interleaved_sliding_window = sliding_window
+                delattr(self.hf_text_config, "sliding_window")
+                sliding_window = None

         self.max_model_len = _get_and_verify_max_len(
             hf_config=self.hf_text_config,
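
In the non-XFORMERS branch above, the window size is moved from sliding_window to a new interleaved_sliding_window attribute, so the scheduler and KV-cache sizing see a full-attention model while the attention layers can still read the window. A sketch of that transformation on a default Gemma2Config (assumes a transformers version that ships Gemma2Config; not part of the diff):

from transformers import Gemma2Config

cfg = Gemma2Config()                       # sliding_window defaults to 4096
cfg.interleaved_sliding_window = cfg.sliding_window
delattr(cfg, "sliding_window")

print(getattr(cfg, "sliding_window", None))  # None -> generic code treats it as full attention
print(cfg.interleaved_sliding_window)        # 4096 -> read only by the attention layers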

View File

@@ -143,12 +143,12 @@ class Gemma2Attention(nn.Module):
             is_neox_style=True,
         )

-        # FIXME(woosuk): While Gemma 2 uses sliding window attention for every
-        # odd layer, vLLM currently ignores it and uses global attention for
-        # all layers.
-        use_sliding_window = (layer_idx % 2 == 1
-                              and config.sliding_window is not None)
-        del use_sliding_window  # Unused.
+        # reference:
+        # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312  # noqa
+        use_sliding_window = (layer_idx % 2 == 0 and
+                              config.interleaved_sliding_window is not None)
+        sliding_window = config.interleaved_sliding_window if \
+            use_sliding_window else None
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
@@ -156,6 +156,7 @@ class Gemma2Attention(nn.Module):
                               cache_config=cache_config,
                               quant_config=quant_config,
                               logits_soft_cap=attn_logits_soft_cap,
+                              per_layer_sliding_window=sliding_window,
                               prefix=f"{prefix}.attn")

     def forward(
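
Under the new rule, layers with an even layer_idx use the interleaved sliding window and odd layers use global attention, matching the referenced Hugging Face implementation (the old, unused computation used the opposite parity). A small sketch of the resulting pattern for a hypothetical 6-layer model:

# sketch: per-layer window assignment under the new rule (layer_idx % 2 == 0)
interleaved_sliding_window = 4096
for layer_idx in range(6):
    use_sliding_window = (layer_idx % 2 == 0
                          and interleaved_sliding_window is not None)
    window = interleaved_sliding_window if use_sliding_window else None
    kind = f"sliding window {window}" if window else "global attention"
    print(f"layer {layer_idx}: {kind}")
# layer 0: sliding window 4096
# layer 1: global attention
# layer 2: sliding window 4096
# layer 3: global attention
# ...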