mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:54:56 +08:00
[Feature] EPLB on Qwen3VLMoe and CompressedTensorsWNA16MoEMethod (#28849)
This commit is contained in:
parent
0075bfffd4
commit
8e38e99829
@ -1921,9 +1921,20 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `CompressedTensorsWNA16MoEMethod` yet."
|
||||
)
|
||||
if expert_load_view is None:
|
||||
raise ValueError("enable_eplb=True requiere expert_load_view != None")
|
||||
if logical_to_physical_map is None:
|
||||
raise ValueError(
|
||||
"enable_eplb=True requiere logical_to_physical_map != None"
|
||||
)
|
||||
if logical_replica_count is None:
|
||||
raise ValueError(
|
||||
"enable_eplb=True requiere logical_replica_count != None"
|
||||
)
|
||||
if not isinstance(layer, FusedMoE):
|
||||
raise TypeError(
|
||||
"EPLB is only supported when `layer` is a instance of FusedMoE."
|
||||
)
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
@ -1940,6 +1951,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype,
|
||||
num_fused_shared_experts=getattr(layer, "num_fused_shared_experts", 0),
|
||||
enable_eplb=enable_eplb,
|
||||
expert_map=expert_map,
|
||||
expert_load_view=expert_load_view,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count,
|
||||
)
|
||||
|
||||
return fused_experts(
|
||||
@ -1956,6 +1973,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_eplb(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
"""
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@ -29,7 +29,9 @@ from collections.abc import Callable, Iterable
|
||||
from itertools import islice
|
||||
|
||||
import torch
|
||||
from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeConfig
|
||||
from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import (
|
||||
Qwen3VLMoeConfig,
|
||||
)
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
@ -44,7 +46,12 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
|
||||
from .interfaces import MixtureOfExperts
|
||||
from .qwen3_moe import (
|
||||
Qwen3MoeForCausalLM,
|
||||
Qwen3MoeModel,
|
||||
Qwen3MoeSparseMoeBlock,
|
||||
)
|
||||
from .qwen3_vl import (
|
||||
Qwen3_VisionTransformer,
|
||||
Qwen3VLDummyInputsBuilder,
|
||||
@ -344,12 +351,56 @@ class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM):
|
||||
)
|
||||
|
||||
|
||||
class Qwen3VLMoeMixtureOfExperts(MixtureOfExperts):
|
||||
def update_physical_experts_metadata(
|
||||
self,
|
||||
num_physical_experts: int,
|
||||
num_local_physical_experts: int,
|
||||
) -> None:
|
||||
assert self.num_local_physical_experts == num_local_physical_experts
|
||||
self.num_physical_experts = num_physical_experts
|
||||
self.num_local_physical_experts = num_local_physical_experts
|
||||
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
|
||||
for layer in self.language_model.model.layers:
|
||||
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
|
||||
moe = layer.mlp
|
||||
moe.n_local_physical_experts = num_local_physical_experts
|
||||
moe.n_physical_experts = num_physical_experts
|
||||
moe.n_redundant_experts = self.num_redundant_experts
|
||||
moe.experts.update_expert_map()
|
||||
|
||||
def set_moe_parameters(self):
|
||||
self.expert_weights = []
|
||||
|
||||
self.moe_layers = []
|
||||
example_moe = None
|
||||
for layer in self.language_model.model.layers:
|
||||
if hasattr(layer, "mlp") and isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
|
||||
example_moe = layer.mlp
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
|
||||
if example_moe is None:
|
||||
raise RuntimeError("No Qwen3Moe layer found in the language_model.")
|
||||
|
||||
# Set MoE hyperparameters
|
||||
self.num_moe_layers = len(self.moe_layers)
|
||||
self.num_expert_groups = 1
|
||||
self.num_shared_experts = 0
|
||||
self.num_logical_experts = example_moe.n_logical_experts
|
||||
self.num_physical_experts = example_moe.n_physical_experts
|
||||
self.num_local_physical_experts = example_moe.n_local_physical_experts
|
||||
self.num_routed_experts = example_moe.n_routed_experts
|
||||
self.num_redundant_experts = example_moe.n_redundant_experts
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Qwen3VLMultiModalProcessor,
|
||||
info=Qwen3VLMoeProcessingInfo,
|
||||
dummy_inputs=Qwen3VLDummyInputsBuilder,
|
||||
)
|
||||
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
||||
class Qwen3VLMoeForConditionalGeneration(
|
||||
Qwen3VLForConditionalGeneration, Qwen3VLMoeMixtureOfExperts
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@ -413,3 +464,6 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
||||
self.deepstack_input_embeds = None
|
||||
self.visual_dim = config.vision_config.out_hidden_size
|
||||
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
|
||||
|
||||
# Set MoE hyperparameters
|
||||
self.set_moe_parameters()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user