[Bugfix] Fix getting device for MoE LoRA (#29475)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-11-27 15:16:07 +08:00 committed by GitHub
parent 11ea5ec1ff
commit c069086b9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 1 deletions

View File

@@ -30,6 +30,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
FusedMoEModularMethod,
)
from .utils import _get_lora_device
class FusedMoEWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: FusedMoE) -> None:
@@ -41,7 +43,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.device = base_layer.w2_weight.device
self.device = _get_lora_device(base_layer)
self._w13_slices = 2
self._inject_lora_into_fused_moe()

View File

@@ -33,6 +33,15 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
# HQQ marlin
elif hasattr(base_layer, "W_q"):
return base_layer.W_q.device
# MoE layer
elif hasattr(base_layer, "w2_weight"):
return base_layer.w2_weight.device
# MoE Compressed Tensor
elif hasattr(base_layer, "w2_weight_packed"):
return base_layer.w2_weight_packed.device
# MoE GPTQ/AWQ/GGUF
elif hasattr(base_layer, "w2_qweight"):
return base_layer.w2_qweight.device
else:
raise ValueError(f"Unsupported base layer: {base_layer}")