[Feature] EPLB on Qwen3VLMoe and CompressedTensorsWNA16MoEMethod (#28849)

This commit is contained in:
JartX 2025-11-20 00:30:08 +01:00 committed by GitHub
parent 0075bfffd4
commit 8e38e99829
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 82 additions and 7 deletions

View File

@ -1921,9 +1921,20 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsWNA16MoEMethod` yet."
)
if expert_load_view is None:
raise ValueError("enable_eplb=True requiere expert_load_view != None")
if logical_to_physical_map is None:
raise ValueError(
"enable_eplb=True requiere logical_to_physical_map != None"
)
if logical_replica_count is None:
raise ValueError(
"enable_eplb=True requiere logical_replica_count != None"
)
if not isinstance(layer, FusedMoE):
raise TypeError(
"EPLB is only supported when `layer` is a instance of FusedMoE."
)
from vllm.model_executor.layers.fused_moe import fused_experts
@ -1940,6 +1951,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
num_fused_shared_experts=getattr(layer, "num_fused_shared_experts", 0),
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
return fused_experts(
@ -1956,6 +1973,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
quant_config=self.moe_quant_config,
)
@property
def supports_eplb(self) -> bool:
return True
class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
"""

View File

@ -15,7 +15,7 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@ -29,7 +29,9 @@ from collections.abc import Callable, Iterable
from itertools import islice
import torch
from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeConfig
from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import (
Qwen3VLMoeConfig,
)
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
@ -44,7 +46,12 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
from .interfaces import MixtureOfExperts
from .qwen3_moe import (
Qwen3MoeForCausalLM,
Qwen3MoeModel,
Qwen3MoeSparseMoeBlock,
)
from .qwen3_vl import (
Qwen3_VisionTransformer,
Qwen3VLDummyInputsBuilder,
@ -344,12 +351,56 @@ class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM):
)
class Qwen3VLMoeMixtureOfExperts(MixtureOfExperts):
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for layer in self.language_model.model.layers:
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
moe = layer.mlp
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def set_moe_parameters(self):
self.expert_weights = []
self.moe_layers = []
example_moe = None
for layer in self.language_model.model.layers:
if hasattr(layer, "mlp") and isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
example_moe = layer.mlp
self.moe_layers.append(layer.mlp.experts)
if example_moe is None:
raise RuntimeError("No Qwen3Moe layer found in the language_model.")
# Set MoE hyperparameters
self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1
self.num_shared_experts = 0
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_redundant_experts = example_moe.n_redundant_experts
@MULTIMODAL_REGISTRY.register_processor(
Qwen3VLMultiModalProcessor,
info=Qwen3VLMoeProcessingInfo,
dummy_inputs=Qwen3VLDummyInputsBuilder,
)
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
class Qwen3VLMoeForConditionalGeneration(
Qwen3VLForConditionalGeneration, Qwen3VLMoeMixtureOfExperts
):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -413,3 +464,6 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
self.deepstack_input_embeds = None
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
# Set MoE hyperparameters
self.set_moe_parameters()