diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py
index 109b65d92cf9..04d6cde555e2 100644
--- a/vllm/model_executor/models/bart.py
+++ b/vllm/model_executor/models/bart.py
@@ -44,7 +44,7 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-from .interfaces import SupportsV0Only
+from .interfaces import SupportsQuant, SupportsV0Only
 from .utils import maybe_prefix
 
 logger = logging.get_logger(__name__)
@@ -697,7 +697,7 @@
         return hidden_states
 
 
-class BartModel(nn.Module):
+class BartModel(nn.Module, SupportsQuant):
     _tied_weights_keys = [
         "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
     ]
@@ -763,7 +763,8 @@
         return decoder_outputs
 
 
-class BartForConditionalGeneration(nn.Module, SupportsV0Only):
+class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
+    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
     base_model_prefix = "model"
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):