mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 13:44:58 +08:00
[Quant] Aria SupportsQuant (#13416)
This commit is contained in:
parent
ac19b519ed
commit
d1b649f1ef
@ -36,7 +36,7 @@ from .idefics2_vision_model import Idefics2VisionConfig
|
|||||||
from .idefics2_vision_model import (
|
from .idefics2_vision_model import (
|
||||||
Idefics2VisionTransformer as Idefics3VisionTransformer)
|
Idefics2VisionTransformer as Idefics3VisionTransformer)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from .interfaces import SupportsMultiModal
|
from .interfaces import SupportsMultiModal, SupportsQuant
|
||||||
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
|
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
|
||||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||||
is_pp_missing_parameter, maybe_prefix,
|
is_pp_missing_parameter, maybe_prefix,
|
||||||
@ -53,7 +53,8 @@ class AriaImagePixelInputs(TypedDict):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class AriaVisionTransformer(Idefics3VisionTransformer):
|
class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant):
|
||||||
|
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -304,11 +305,17 @@ class AriaTextDecoderLayer(LlamaDecoderLayer):
|
|||||||
self.mlp = AriaTextMoELayer(config, quant_config=quant_config)
|
self.mlp = AriaTextMoELayer(config, quant_config=quant_config)
|
||||||
|
|
||||||
|
|
||||||
class AriaTextModel(LlamaModel):
|
class AriaTextModel(LlamaModel, SupportsQuant):
|
||||||
"""
|
"""
|
||||||
Custom LlamaModel for the AriaMoE model which modifies the standard
|
Custom LlamaModel for the AriaMoE model which modifies the standard
|
||||||
LlamaModel by replacing the `LlamaDecoderLayer` with `MoEDecoderLayer`.
|
LlamaModel by replacing the `LlamaDecoderLayer` with `MoEDecoderLayer`.
|
||||||
"""
|
"""
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||||
|
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||||
|
"experts.w13_weight": ["experts.fc1.weight"],
|
||||||
|
"experts.w2_weight": ["experts.fc2.weight"],
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__(vllm_config=vllm_config,
|
super().__init__(vllm_config=vllm_config,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user