[Quant] Aria SupportsQuant (#13416)

This commit is contained in:
Kyle Sayers 2025-02-18 00:51:09 -05:00 committed by GitHub
parent ac19b519ed
commit d1b649f1ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -36,7 +36,7 @@ from .idefics2_vision_model import Idefics2VisionConfig
 from .idefics2_vision_model import (
     Idefics2VisionTransformer as Idefics3VisionTransformer)
 # yapf: enable
-from .interfaces import SupportsMultiModal
+from .interfaces import SupportsMultiModal, SupportsQuant
 from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     is_pp_missing_parameter, maybe_prefix,
@@ -53,7 +53,8 @@ class AriaImagePixelInputs(TypedDict):
     """
-class AriaVisionTransformer(Idefics3VisionTransformer):
+class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant):
+    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
     def __init__(
         self,
@@ -304,11 +305,17 @@ class AriaTextDecoderLayer(LlamaDecoderLayer):
         self.mlp = AriaTextMoELayer(config, quant_config=quant_config)
-class AriaTextModel(LlamaModel):
+class AriaTextModel(LlamaModel, SupportsQuant):
     """
     Custom LlamaModel for the AriaMoE model which modifies the standard
     LlamaModel by replacing the `LlamaDecoderLayer` with `MoEDecoderLayer`.
     """
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
+        "experts.w13_weight": ["experts.fc1.weight"],
+        "experts.w2_weight": ["experts.fc2.weight"],
+    }
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config,