mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-20 22:14:43 +08:00
[SupportsQuant] Bert, Blip, Blip2, Bloom (#15573)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
parent
84884cd9ac
commit
421c462948
@ -26,7 +26,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.transformers_utils.config import (
|
||||
get_cross_encoder_activation_function)
|
||||
|
||||
from .interfaces import SupportsCrossEncoding, SupportsV0Only
|
||||
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
|
||||
from .utils import WeightsMapper, maybe_prefix
|
||||
|
||||
|
||||
@ -313,7 +313,8 @@ class BertOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertModel(nn.Module):
|
||||
class BertModel(nn.Module, SupportsQuant):
|
||||
packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
@ -385,7 +386,7 @@ class BertModel(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class BertEmbeddingModel(nn.Module, SupportsV0Only):
|
||||
class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
|
||||
"""A model that uses Bert to provide embedding functionalities.
|
||||
|
||||
This class encapsulates the BertModel and provides an interface for
|
||||
@ -443,7 +444,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only):
|
||||
softmax=False)
|
||||
|
||||
|
||||
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
||||
SupportsQuant):
|
||||
"""A model that uses Bert to provide embedding functionalities.
|
||||
|
||||
This class encapsulates the BertModel and provides an interface for
|
||||
|
||||
@ -16,6 +16,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from .interfaces import SupportsQuant
|
||||
|
||||
|
||||
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||
assert image_size % patch_size == 0
|
||||
@ -243,9 +245,10 @@ class BlipEncoder(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BlipVisionModel(nn.Module):
|
||||
class BlipVisionModel(nn.Module, SupportsQuant):
|
||||
config_class = BlipVisionConfig
|
||||
main_input_name = "pixel_values"
|
||||
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@ -24,7 +24,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .blip import BlipVisionModel
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
|
||||
SupportsQuant)
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
|
||||
@ -498,7 +499,8 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
|
||||
@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor,
|
||||
info=Blip2ProcessingInfo,
|
||||
dummy_inputs=Blip2DummyInputsBuilder)
|
||||
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
SupportsQuant):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
|
||||
|
||||
@ -42,7 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP, SupportsV0Only
|
||||
from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
@ -279,7 +279,7 @@ class BloomModel(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only):
|
||||
class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user