mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:35:17 +08:00
[CI] Fix tests/v1/e2e/test_kv_sharing_fast_prefill.py import on test (#22815)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
c9232d41f4
commit
12817a8ac7
@ -11,7 +11,8 @@ from vllm import LLM, SamplingParams
|
|||||||
from vllm.config import CompilationConfig, CompilationLevel
|
from vllm.config import CompilationConfig, CompilationLevel
|
||||||
from vllm.distributed import cleanup_dist_env_and_memory
|
from vllm.distributed import cleanup_dist_env_and_memory
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.models.gemma3n import Gemma3nForConditionalGeneration
|
from vllm.model_executor.models.gemma3n_mm import (
|
||||||
|
Gemma3nForConditionalGeneration)
|
||||||
from vllm.model_executor.models.registry import ModelRegistry
|
from vllm.model_executor.models.registry import ModelRegistry
|
||||||
from vllm.model_executor.models.utils import extract_layer_index
|
from vllm.model_executor.models.utils import extract_layer_index
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
@ -32,12 +33,13 @@ class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):
|
|||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
hidden_states = super().forward(input_ids, positions,
|
||||||
inputs_embeds, **kwargs)
|
intermediate_tensors, inputs_embeds,
|
||||||
|
**kwargs)
|
||||||
attn_metadata = get_forward_context().attn_metadata
|
attn_metadata = get_forward_context().attn_metadata
|
||||||
# attn_metadata is None during dummy runs
|
# attn_metadata is None during dummy runs
|
||||||
if (attn_metadata is not None
|
if (attn_metadata is not None
|
||||||
and self.cache_config.kv_sharing_fast_prefill):
|
and self.language_model.cache_config.kv_sharing_fast_prefill):
|
||||||
assert isinstance(attn_metadata, dict) # true in V1
|
assert isinstance(attn_metadata, dict) # true in V1
|
||||||
# Gemma3n-E2B has 30 layers, with last 20 layers being
|
# Gemma3n-E2B has 30 layers, with last 20 layers being
|
||||||
# cross-decoder layers. Check attention metadata is correct
|
# cross-decoder layers. Check attention metadata is correct
|
||||||
@ -52,7 +54,7 @@ class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):
|
|||||||
|
|
||||||
# Last layer will be a KV sharing layer
|
# Last layer will be a KV sharing layer
|
||||||
layer_attn_metadata = attn_metadata[
|
layer_attn_metadata = attn_metadata[
|
||||||
self.model.language_model.layers[-1].self_attn.attn.layer_name]
|
self.language_model.model.layers[-1].self_attn.attn.layer_name]
|
||||||
logits_indices_padded = (layer_attn_metadata.logits_indices_padded)
|
logits_indices_padded = (layer_attn_metadata.logits_indices_padded)
|
||||||
assert logits_indices_padded is not None
|
assert logits_indices_padded is not None
|
||||||
num_logits_indices = layer_attn_metadata.num_logits_indices
|
num_logits_indices = layer_attn_metadata.num_logits_indices
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user