diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py
index 0eb6562bec6c..1b925742c300 100644
--- a/vllm/lora/layers/fused_moe.py
+++ b/vllm/lora/layers/fused_moe.py
@@ -30,6 +30,8 @@
 from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
     FusedMoEModularMethod,
 )
+from .utils import _get_lora_device
+
 
 class FusedMoEWithLoRA(BaseLayerWithLoRA):
     def __init__(self, base_layer: FusedMoE) -> None:
@@ -41,7 +43,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         )
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
-        self.device = base_layer.w2_weight.device
+        self.device = _get_lora_device(base_layer)
 
         self._w13_slices = 2
         self._inject_lora_into_fused_moe()
diff --git a/vllm/lora/layers/utils.py b/vllm/lora/layers/utils.py
index 2da90f180ee7..74403240f6cc 100644
--- a/vllm/lora/layers/utils.py
+++ b/vllm/lora/layers/utils.py
@@ -33,6 +33,15 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
     # HQQ marlin
     elif hasattr(base_layer, "W_q"):
         return base_layer.W_q.device
+    # MoE layer
+    elif hasattr(base_layer, "w2_weight"):
+        return base_layer.w2_weight.device
+    # MoE Compressed Tensor
+    elif hasattr(base_layer, "w2_weight_packed"):
+        return base_layer.w2_weight_packed.device
+    # MoE GPTQ/AWQ/GGUF
+    elif hasattr(base_layer, "w2_qweight"):
+        return base_layer.w2_qweight.device
     else:
         raise ValueError(f"Unsupported base layer: {base_layer}")
 
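
For context, the change replaces a hard-coded read of base_layer.w2_weight.device (which raises AttributeError on quantized MoE layers that store packed weights under other names) with the shared _get_lora_device helper, which dispatches on whichever weight attribute the layer actually exposes. Below is a minimal, self-contained sketch of that dispatch pattern; the Mock* classes are hypothetical stand-ins for real vLLM layers, not part of the PR, and the helper body mirrors only the new MoE branches.

    import torch
    import torch.nn as nn


    class MockUnquantizedMoE(nn.Module):
        """Unquantized fused MoE: exposes a plain w2_weight tensor."""

        def __init__(self) -> None:
            super().__init__()
            self.w2_weight = nn.Parameter(torch.empty(1))


    class MockGPTQMoE(nn.Module):
        """GPTQ/AWQ/GGUF-style MoE: exposes a packed w2_qweight instead."""

        def __init__(self) -> None:
            super().__init__()
            self.w2_qweight = nn.Parameter(torch.empty(1), requires_grad=False)


    def _get_lora_device(base_layer: nn.Module) -> torch.device:
        # Mirrors the new MoE branches added in vllm/lora/layers/utils.py:
        # probe for each known weight attribute and return its device.
        if hasattr(base_layer, "w2_weight"):
            return base_layer.w2_weight.device
        elif hasattr(base_layer, "w2_weight_packed"):
            return base_layer.w2_weight_packed.device
        elif hasattr(base_layer, "w2_qweight"):
            return base_layer.w2_qweight.device
        raise ValueError(f"Unsupported base layer: {base_layer}")


    print(_get_lora_device(MockUnquantizedMoE()))  # e.g. cpu
    print(_get_lora_device(MockGPTQMoE()))         # e.g. cpu

Note the branch order: the plain w2_weight check comes first, which is safe because quantized layers store their weights under the packed names rather than alongside an unquantized tensor.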