mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:01:40 +08:00
[Misc] Remove redundant TypeVar from base model (#12248)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
1f1542afa9
commit
f2e9f2a3be
@ -3,7 +3,6 @@ from typing import (TYPE_CHECKING, List, Optional, Protocol, Type, Union,
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
from typing_extensions import TypeIs, TypeVar
|
||||
|
||||
from vllm.logger import init_logger
|
||||
@ -19,9 +18,6 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# The type of HF config
|
||||
C_co = TypeVar("C_co", bound=PretrainedConfig, covariant=True)
|
||||
|
||||
# The type of hidden states
|
||||
# Currently, T = torch.Tensor for all models except for Medusa
|
||||
# which has T = List[torch.Tensor]
|
||||
@ -34,7 +30,7 @@ T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class VllmModel(Protocol[C_co, T_co]):
|
||||
class VllmModel(Protocol[T_co]):
|
||||
"""The interface required for all models in vLLM."""
|
||||
|
||||
def __init__(
|
||||
@ -97,7 +93,7 @@ def is_vllm_model(
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class VllmModelForTextGeneration(VllmModel[C_co, T], Protocol[C_co, T]):
|
||||
class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
|
||||
"""The interface required for all generative models in vLLM."""
|
||||
|
||||
def compute_logits(
|
||||
@ -143,7 +139,7 @@ def is_text_generation_model(
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class VllmModelForPooling(VllmModel[C_co, T], Protocol[C_co, T]):
|
||||
class VllmModelForPooling(VllmModel[T], Protocol[T]):
|
||||
"""The interface required for all pooling models in vLLM."""
|
||||
|
||||
def pooler(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user