diff --git a/tests/v1/e2e/test_kv_sharing_fast_prefill.py b/tests/v1/e2e/test_kv_sharing_fast_prefill.py index f5a7b9cc276b..d72e50e5196b 100644 --- a/tests/v1/e2e/test_kv_sharing_fast_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_fast_prefill.py @@ -11,7 +11,8 @@ from vllm import LLM, SamplingParams from vllm.config import CompilationConfig, CompilationLevel from vllm.distributed import cleanup_dist_env_and_memory from vllm.forward_context import get_forward_context -from vllm.model_executor.models.gemma3n import Gemma3nForConditionalGeneration +from vllm.model_executor.models.gemma3n_mm import ( + Gemma3nForConditionalGeneration) from vllm.model_executor.models.registry import ModelRegistry from vllm.model_executor.models.utils import extract_layer_index from vllm.sequence import IntermediateTensors @@ -32,12 +33,13 @@ class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration): inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds, **kwargs) + hidden_states = super().forward(input_ids, positions, + intermediate_tensors, inputs_embeds, + **kwargs) attn_metadata = get_forward_context().attn_metadata # attn_metadata is None during dummy runs if (attn_metadata is not None - and self.cache_config.kv_sharing_fast_prefill): + and self.language_model.cache_config.kv_sharing_fast_prefill): assert isinstance(attn_metadata, dict) # true in V1 # Gemma3n-E2B has 30 layers, with last 20 layers being # cross-decoder layers. Check attention metadata is correct @@ -52,7 +54,7 @@ class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration): # Last layer will be a KV sharing layer layer_attn_metadata = attn_metadata[ - self.model.language_model.layers[-1].self_attn.attn.layer_name] + self.language_model.model.layers[-1].self_attn.attn.layer_name] logits_indices_padded = (layer_attn_metadata.logits_indices_padded) assert logits_indices_padded is not None num_logits_indices = layer_attn_metadata.num_logits_indices