[SupportsQuant] Bert, Blip, Blip2, Bloom (#15573)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Kyle Sayers 2025-04-03 11:23:19 -04:00 committed by GitHub
parent 84884cd9ac
commit 421c462948
4 changed files with 16 additions and 9 deletions
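
For context: SupportsQuant is the model interface in vllm/model_executor/models/interfaces.py that marks an architecture as quantization-aware, and packed_modules_mapping declares which checkpoint weights are packed into each fused parameter. Below is a minimal sketch of the mixin pattern; the two attribute names come from this diff, but the __new__ hook body is an assumption about how the config gets attached, not a copy of the vLLM implementation.

    from typing import ClassVar, Optional

    from vllm.model_executor.layers.quantization import QuantizationConfig


    class SupportsQuantSketch:
        """Sketch of a quantization-aware model mixin (assumed internals)."""
        # Fused parameter name -> checkpoint weights packed into it.
        packed_modules_mapping: ClassVar[dict[str, list[str]]] = {}
        quant_config: Optional[QuantizationConfig] = None

        def __new__(cls, *args, **kwargs):
            instance = super().__new__(cls)
            # Assumption: fish the QuantizationConfig out of the constructor
            # arguments and attach it, so every subclass gains .quant_config
            # without modifying its own __init__.
            quant_config = next(
                (arg for arg in (*args, *kwargs.values())
                 if isinstance(arg, QuantizationConfig)), None)
            if quant_config is not None:
                instance.quant_config = quant_config
            return instance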

vllm/model_executor/models/bert.py

@@ -26,7 +26,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.transformers_utils.config import (
     get_cross_encoder_activation_function)

-from .interfaces import SupportsCrossEncoding, SupportsV0Only
+from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
 from .utils import WeightsMapper, maybe_prefix
@@ -313,7 +313,8 @@ class BertOutput(nn.Module):
         return hidden_states


-class BertModel(nn.Module):
+class BertModel(nn.Module, SupportsQuant):
+    packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}

     def __init__(self,
                  *,
@@ -385,7 +386,7 @@ class BertModel(nn.Module):
         return loaded_params


-class BertEmbeddingModel(nn.Module, SupportsV0Only):
+class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.

     This class encapsulates the BertModel and provides an interface for
@@ -443,7 +444,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only):
                            softmax=False)


-class BertForSequenceClassification(nn.Module, SupportsCrossEncoding):
+class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
+                                    SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.

     This class encapsulates the BertModel and provides an interface for
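
The mapping added to BertModel records that the fused qkv_proj parameter packs the checkpoint's separate query, key, and value weights, which is what lets a quantization method resolve per-weight schemes under the fused name. A small runnable illustration of what the mapping encodes (unpack_module_name is a hypothetical helper for this example, not vLLM API):

    def unpack_module_name(name: str,
                           mapping: dict[str, list[str]]) -> list[str]:
        """Expand a fused module name into the checkpoint names it packs."""
        for packed, shards in mapping.items():
            if packed in name:
                return [name.replace(packed, shard) for shard in shards]
        return [name]

    # BertModel's mapping from this diff:
    mapping = {"qkv_proj": ["query", "key", "value"]}
    print(unpack_module_name("encoder.layer.0.attention.qkv_proj", mapping))
    # ['encoder.layer.0.attention.query', 'encoder.layer.0.attention.key',
    #  'encoder.layer.0.attention.value']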

vllm/model_executor/models/blip.py

@@ -16,6 +16,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from .interfaces import SupportsQuant
+

 def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
     assert image_size % patch_size == 0
@@ -243,9 +245,10 @@ class BlipEncoder(nn.Module):
         return hidden_states


-class BlipVisionModel(nn.Module):
+class BlipVisionModel(nn.Module, SupportsQuant):
     config_class = BlipVisionConfig
     main_input_name = "pixel_values"
+    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

     def __init__(
         self,
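
BlipVisionModel uses the Hugging Face vision-tower names q_proj/k_proj/v_proj rather than BERT's query/key/value, so its mapping differs while the fused target qkv_proj stays the same. The mapping mirrors the stacked-parameter routing that vLLM's load_weights loops perform; a simplified, runnable sketch (route_weight and its return shape are illustrative only):

    stacked_params_mapping = [
        # (fused param name, checkpoint weight name, shard id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]

    def route_weight(name: str) -> tuple[str, str | None]:
        """Map a checkpoint weight name to (fused param name, shard id)."""
        for fused, orig, shard_id in stacked_params_mapping:
            if orig in name:
                return name.replace(orig, fused), shard_id
        return name, None

    print(route_weight("encoder.layers.0.self_attn.q_proj.weight"))
    # ('encoder.layers.0.self_attn.qkv_proj.weight', 'q')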

vllm/model_executor/models/blip2.py

@@ -24,7 +24,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors

 from .blip import BlipVisionModel
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
+                         SupportsQuant)
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
@@ -498,7 +499,8 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
 @MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor,
                                         info=Blip2ProcessingInfo,
                                         dummy_inputs=Blip2DummyInputsBuilder)
-class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
+                                    SupportsQuant):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
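
Blip2ForConditionalGeneration composes three interfaces; because the mixin does its work in __new__ (per the sketch above), adding SupportsQuant to the base list is enough, and the class's own __init__ stays untouched. A toy, runnable demonstration of that property (all names below are stand-ins, not vLLM code):

    class FakeQuantConfig:
        pass

    class QuantMixin:
        quant_config = None

        def __new__(cls, *args, **kwargs):
            instance = super().__new__(cls)
            instance.quant_config = next(
                (a for a in (*args, *kwargs.values())
                 if isinstance(a, FakeQuantConfig)), None)
            return instance

    class ToyModel(QuantMixin):
        def __init__(self, quant_config=None):
            self.ready = True  # __init__ never touches quant_config

    model = ToyModel(quant_config=FakeQuantConfig())
    print(model.quant_config is not None)  # True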

vllm/model_executor/models/bloom.py

@@ -42,7 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsPP, SupportsV0Only
+from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only
 from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -279,7 +279,7 @@ class BloomModel(nn.Module):
         return hidden_states


-class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only):
+class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
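
With all four architectures now declaring SupportsQuant, they can be served from quantized checkpoints through the usual entry point. A hedged usage sketch; the model path is a placeholder, and the quantization method must match how the checkpoint was produced:

    from vllm import LLM

    llm = LLM(
        model="path/to/bloom-gptq-checkpoint",  # placeholder, not a real repo
        quantization="gptq",
    )
    out = llm.generate("The capital of France is")
    print(out[0].outputs[0].text)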