[Bugfix] Fix getting device for MoE LoRA (#29475)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent 11ea5ec1ff
commit c069086b9c
@@ -30,6 +30,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
     FusedMoEModularMethod,
 )
 
+from .utils import _get_lora_device
+
 
 class FusedMoEWithLoRA(BaseLayerWithLoRA):
     def __init__(self, base_layer: FusedMoE) -> None:
@@ -41,7 +43,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         )
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
-        self.device = base_layer.w2_weight.device
+        self.device = _get_lora_device(base_layer)
         self._w13_slices = 2
         self._inject_lora_into_fused_moe()
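For context, here is a minimal sketch of the failure this hunk fixes, assuming a quantized FusedMoE base layer. QuantizedMoEStub is a hypothetical stand-in, not a vLLM class; only the attribute layout (w2_qweight instead of w2_weight) matters.

# Hypothetical stand-in for a quantized FusedMoE layer; real class names
# differ, only the weight attribute name is relevant here.
import torch
import torch.nn as nn


class QuantizedMoEStub(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # GPTQ/AWQ-style MoE layers store packed expert weights under
        # w2_qweight; there is no w2_weight attribute to read a device from.
        self.w2_qweight = torch.empty(8, 16, dtype=torch.int32)


base_layer = QuantizedMoEStub()

try:
    # Old lookup in FusedMoEWithLoRA.__init__: assumes unquantized weights.
    device = base_layer.w2_weight.device
except AttributeError:
    # What the helper-based lookup falls back to for this layout.
    device = base_layer.w2_qweight.device

print(device)  # cpu (or wherever the stub tensor was allocated)

Delegating to _get_lora_device, as the hunk above does, moves this attribute probing into one place instead of hard-coding w2_weight in the LoRA wrapper.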
@@ -33,6 +33,15 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
     # HQQ marlin
     elif hasattr(base_layer, "W_q"):
         return base_layer.W_q.device
+    # MoE layer
+    elif hasattr(base_layer, "w2_weight"):
+        return base_layer.w2_weight.device
+    # MoE Compressed Tensor
+    elif hasattr(base_layer, "w2_weight_packed"):
+        return base_layer.w2_weight_packed.device
+    # MoE GPTQ/AWQ/GGUF
+    elif hasattr(base_layer, "w2_qweight"):
+        return base_layer.w2_qweight.device
     else:
         raise ValueError(f"Unsupported base layer: {base_layer}")
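The added branches all follow the same pattern: probe for a known expert-weight attribute and return its device. Below is a table-driven sketch of that dispatch; the function name and the candidate-attribute tuple are illustrative, not vLLM's actual API, and the real _get_lora_device keeps an explicit elif chain that also covers the linear-layer attributes handled earlier in the function (e.g. the HQQ marlin W_q branch shown as context above).

# Illustrative sketch only: get_moe_lora_device and _MOE_WEIGHT_ATTRS are
# hypothetical names, mirroring the attribute checks added in the diff.
import torch
import torch.nn as nn

_MOE_WEIGHT_ATTRS = (
    "w2_weight",         # unquantized FusedMoE
    "w2_weight_packed",  # compressed-tensors MoE
    "w2_qweight",        # GPTQ/AWQ/GGUF MoE
)


def get_moe_lora_device(base_layer: nn.Module) -> torch.device:
    """Return the device of the first recognised expert-weight attribute."""
    for attr in _MOE_WEIGHT_ATTRS:
        weight = getattr(base_layer, attr, None)
        if weight is not None:
            return weight.device
    raise ValueError(f"Unsupported base layer: {base_layer}")


class _UnquantizedMoEStub(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.w2_weight = torch.empty(8, 16)


print(get_moe_lora_device(_UnquantizedMoEStub()))  # cpu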