diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 55c6e3d7f555..18c075cfa482 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -380,9 +380,9 @@ Specified using `--task generate`. | `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | | | `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | | ✅︎ | ✅︎ | +| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | | ✅︎ | ✅︎ | +| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ | | `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ | | `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | diff --git a/vllm/lora/models.py b/vllm/lora/models.py index bff4e912578c..521bb079da41 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -29,6 +29,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor, get_supported_lora_modules, is_regex_target_modules, parse_fine_tuned_lora_name, replace_submodule) +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models import SupportsLoRA, supports_multimodal from vllm.model_executor.models.interfaces import is_pooling_model @@ -60,6 +61,17 @@ def get_lora_id(): return _GLOBAL_LORA_ID +def is_moe_model(model: nn.Module) -> bool: + """Checks if the model contains FusedMoE layers and warns the user.""" + if any(isinstance(module, FusedMoE) for module in model.modules()): + logger.warning_once( + "For MoE models, vLLM currently does not support fused MoE LoRA " + "inference. Please ensure that the loaded LoRA model does not " + "contain expert weights.") + return True + return False + + class LoRAModel(AdapterModel): """A LoRA fine-tuned model.""" @@ -375,6 +387,7 @@ class LoRAModelManager(AdapterModelManager): # text modules (e.g. ChatGLM) and hasattr(self.model, "get_mm_mapping")) self.is_pooling_model = is_pooling_model(self.model) + self.is_moe_model = is_moe_model(self.model) self.packed_modules: dict[str, list[str]] = {} self.modules: dict[str, BaseLayerWithLoRA] = {} # Dict instead of a set for compatibility with LRUCache. diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 84bae87804c1..b061e2f69a6c 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -53,7 +53,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP +from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -448,8 +448,7 @@ class Qwen2MoeModel(nn.Module): if weight_name not in name: continue name = name.replace(weight_name, param_name) - if "layers.13.mlp.experts.w2_weight" in name: - pass + # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue @@ -494,7 +493,7 @@ class Qwen2MoeModel(nn.Module): return loaded_params -class Qwen2MoeForCausalLM(nn.Module, SupportsPP): +class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): fall_back_to_pt_during_load = False packed_modules_mapping = { diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 0f749b3e38f1..12899c28016b 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -50,7 +50,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP +from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -482,7 +482,7 @@ class Qwen3MoeModel(nn.Module): return loaded_params -class Qwen3MoeForCausalLM(nn.Module, SupportsPP): +class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): packed_modules_mapping = { "qkv_proj": [ "q_proj",