Fix failing MyGemma2Embedding test (#13820)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-02-25 20:33:03 +00:00 committed by GitHub
parent f75aa72732
commit 34e3494e70
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,11 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, Optional, Tuple, Union
import torch
import torch.nn as nn
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.models.gemma2 import Gemma2Model
@ -37,16 +36,12 @@ class MyGemma2Embedding(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(
input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)