# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import random
from typing import Optional, Union

import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.forward_context import get_forward_context
from vllm.model_executor.models.gemma3n_mm import (
    Gemma3nForConditionalGeneration)
from vllm.model_executor.models.registry import ModelRegistry
from vllm.model_executor.models.utils import extract_layer_index
from vllm.sequence import IntermediateTensors

from ...utils import fork_new_process_for_each_test

# global seed
SEED = 42


class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):
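    """Test double whose forward() checks the fast-prefill attention metadata
    and verifies that positions outside logits_indices do not change the
    sampled outputs."""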

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = super().forward(input_ids, positions,
                                        intermediate_tensors, inputs_embeds,
                                        **kwargs)
        attn_metadata = get_forward_context().attn_metadata
        # attn_metadata is None during dummy runs
        if (attn_metadata is not None
                and self.language_model.cache_config.kv_sharing_fast_prefill):
            assert isinstance(attn_metadata, dict)  # true in V1
            # Gemma3n-E2B has 30 layers, the last 20 of which are
            # cross-decoder layers. Check that the attention metadata is
            # correct for both groups.
            for layer_name, metadata in attn_metadata.items():
                layer_idx = extract_layer_index(layer_name)
                if layer_idx >= 20:
                    assert hasattr(metadata, 'logits_indices_padded')
                    assert hasattr(metadata, 'num_logits_indices')
                else:
                    assert not hasattr(metadata, 'logits_indices_padded')
                    assert not hasattr(metadata, 'num_logits_indices')

            # The last layer is a KV sharing layer
            layer_attn_metadata = attn_metadata[
                self.language_model.model.layers[-1].self_attn.attn.layer_name]
            logits_indices_padded = layer_attn_metadata.logits_indices_padded
            assert logits_indices_padded is not None
            num_logits_indices = layer_attn_metadata.num_logits_indices
            assert num_logits_indices > 0

            # Reset the hidden states to random values, keeping valid values
            # only at logits_indices. Because logits_indices are the only
            # positions used for output token sampling, this still produces
            # the same outputs.
            logits_hs = hidden_states[logits_indices_padded]
            hidden_states = torch.randn_like(hidden_states)
            gen_indices = logits_indices_padded[:num_logits_indices]
            hidden_states[gen_indices] = logits_hs[:num_logits_indices]

        return hidden_states


@pytest.fixture
def test_prompts():
    """
    Adapted from tests/v1/e2e/test_spec_decode.py
    """
    prompt_types = ["repeat", "sentence"]
    # A higher number of prompts increases the chance of a numerics mismatch,
    # since matrix multiplication numerics depend on the batch dimension.
    num_prompts = 10
    prompts = []

    random.seed(0)
    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)

    for kind in random_prompt_type_choices:
        word_choices = ["test", "temp", "hello", "where"]
        word = random.choice(word_choices)
        if kind == "repeat":
            prompt = f"""please repeat the word '{word}' 10 times."""
        elif kind == "sentence":
            prompt = f"""please give a ten-word sentence that
            uses the word {word} at least once."""
        else:
            raise ValueError(f"Unknown prompt type: {kind}")
        prompts.append(prompt)

    return prompts


def cleanup(llm: LLM, compilation_config: CompilationConfig):
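    """Best-effort teardown so the next in-process LLM run starts with free
    GPU memory."""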
    # Hacky: the lines below are required to free up memory for the next test.
    # When VLLM_ENABLE_V1_MULTIPROCESSING=0 is set, `del llm` alone is not
    # sufficient.
    # TODO(sarckk): when enforce_eager=False, memory is not freed;
    # find out why and re-enable the test for the enforce_eager=False case.
    llm_engine = llm.llm_engine.engine_core.engine_core
    model_runner = llm_engine.model_executor.driver_worker.worker.model_runner
    del model_runner.model
    del model_runner.kv_caches
    del compilation_config.static_forward_context
    compilation_config.static_forward_context = {}

    del llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()


@fork_new_process_for_each_test
@pytest.mark.parametrize("enforce_eager", [True])
def test_kv_sharing_fast_prefill(
    monkeypatch: pytest.MonkeyPatch,
    enforce_eager: bool,
    test_prompts: list[str],
):
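    """Generations with kv_sharing_fast_prefill enabled should match a
    baseline run of the same model, prompts, and seed with it disabled."""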
    ModelRegistry.register_model("Gemma3nForConditionalGeneration",
                                 TestGemma3nForConditionalGeneration)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
    compilation_config = CompilationConfig(
        # This lets the vLLM compilation backend handle allocating and
        # managing buffers for the CUDA graph.
        cudagraph_copy_inputs=True,
        level=CompilationLevel.PIECEWISE
        if not enforce_eager else CompilationLevel.NO_COMPILATION)

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        # Make scheduling deterministic for reproducibility
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

        llm = LLM(
            model="google/gemma-3n-E2B-it",
            enforce_eager=enforce_eager,
            compilation_config=compilation_config,
            seed=SEED,
        )
        ref_responses = llm.generate(test_prompts, sampling_params)

        cleanup(llm, compilation_config)

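        # Second run: same model, prompts, and seed, but with fast-prefill KV
        # sharing enabled.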
        llm = LLM(model="google/gemma-3n-E2B-it",
                  enforce_eager=enforce_eager,
                  compilation_config=compilation_config,
                  seed=SEED,
                  kv_sharing_fast_prefill=True)
        optimized_responses = llm.generate(test_prompts, sampling_params)

        cleanup(llm, compilation_config)

        misses = 0

        for ref_response, optimized_response in zip(ref_responses,
                                                     optimized_responses):
            if ref_response.outputs[0].text != optimized_response.outputs[
                    0].text:
                misses += 1

        assert misses == 0