mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 03:04:56 +08:00
[Bugfix] Fix getting device for MoE LoRA (#29475)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
11ea5ec1ff
commit
c069086b9c
@@ -30,6 +30,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
     FusedMoEModularMethod,
 )
 
+from .utils import _get_lora_device
+
 
 class FusedMoEWithLoRA(BaseLayerWithLoRA):
     def __init__(self, base_layer: FusedMoE) -> None:
@@ -41,7 +43,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         )
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
-        self.device = base_layer.w2_weight.device
+        self.device = _get_lora_device(base_layer)
         self._w13_slices = 2
         self._inject_lora_into_fused_moe()
 
|
|||||||
@@ -33,6 +33,15 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
     # HQQ marlin
     elif hasattr(base_layer, "W_q"):
         return base_layer.W_q.device
+    # MoE layer
+    elif hasattr(base_layer, "w2_weight"):
+        return base_layer.w2_weight.device
+    # MoE Compressed Tensor
+    elif hasattr(base_layer, "w2_weight_packed"):
+        return base_layer.w2_weight_packed.device
+    # MoE GPTQ/AWQ/GGUF
+    elif hasattr(base_layer, "w2_qweight"):
+        return base_layer.w2_qweight.device
     else:
         raise ValueError(f"Unsupported base layer: {base_layer}")
 
|
|||||||
Loading…
x
Reference in New Issue
Block a user