import importlib import pickle import subprocess import sys import tempfile from abc import ABC, abstractmethod from dataclasses import dataclass, field from functools import lru_cache from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union import cloudpickle import torch.nn as nn from vllm.logger import init_logger from vllm.utils import is_hip from .interfaces import (has_inner_state, is_attention_free, supports_multimodal, supports_pp) from .interfaces_base import is_embedding_model, is_text_generation_model logger = init_logger(__name__) # yapf: disable _TEXT_GENERATION_MODELS = { # [Decoder-only] "AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b "BloomForCausalLM": ("bloom", "BloomForCausalLM"), # ChatGLMModel supports multimodal "CohereForCausalLM": ("commandr", "CohereForCausalLM"), "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"), "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), "GraniteForCausalLM": ("granite", "GraniteForCausalLM"), "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"), "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"), "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"), "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), # QWenLMHeadModel supports multimodal "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), "SolarForCausalLM": ("solar", "SolarForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), # [Encoder-decoder] "BartModel": ("bart", "BartForConditionalGeneration"), "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), } _EMBEDDING_MODELS = { # [Text-only] "BertModel": ("bert", "BertEmbeddingModel"), "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"), "MistralModel": ("llama", "LlamaEmbeddingModel"), "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), # [Multimodal] "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), } _MULTIMODAL_MODELS = { # [Decoder-only] "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "InternVLChatModel": ("internvl", "InternVLChatModel"), "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501 "MiniCPMV": ("minicpmv", "MiniCPMV"), "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"), "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501 "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 "UltravoxModel": ("ultravox", "UltravoxModel"), # [Encoder-decoder] "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 } _SPECULATIVE_DECODING_MODELS = { "EAGLEModel": ("eagle", "EAGLE"), "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), } # yapf: enable _VLLM_MODELS = { **_TEXT_GENERATION_MODELS, **_EMBEDDING_MODELS, **_MULTIMODAL_MODELS, **_SPECULATIVE_DECODING_MODELS, } # Models not supported by ROCm. _ROCM_UNSUPPORTED_MODELS: List[str] = [] # Models partially supported by ROCm. # Architecture -> Reason. _ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in " "Triton flash attention. For half-precision SWA support, " "please use CK flash attention by setting " "`VLLM_USE_TRITON_FLASH_ATTN=0`") _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { "Qwen2ForCausalLM": _ROCM_SWA_REASON, "MistralForCausalLM": _ROCM_SWA_REASON, "MixtralForCausalLM": _ROCM_SWA_REASON, "PaliGemmaForConditionalGeneration": ("ROCm flash attention does not yet " "fully support 32-bit precision on PaliGemma"), "Phi3VForCausalLM": ("ROCm Triton flash attention may run into compilation errors due to " "excessive use of shared memory. If this happens, disable Triton FA " "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") } @dataclass(frozen=True) class _ModelInfo: is_text_generation_model: bool is_embedding_model: bool supports_multimodal: bool supports_pp: bool has_inner_state: bool is_attention_free: bool @staticmethod def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": return _ModelInfo( is_text_generation_model=is_text_generation_model(model), is_embedding_model=is_embedding_model(model), supports_multimodal=supports_multimodal(model), supports_pp=supports_pp(model), has_inner_state=has_inner_state(model), is_attention_free=is_attention_free(model), ) class _BaseRegisteredModel(ABC): @abstractmethod def inspect_model_cls(self) -> _ModelInfo: raise NotImplementedError @abstractmethod def load_model_cls(self) -> Type[nn.Module]: raise NotImplementedError @dataclass(frozen=True) class _RegisteredModel(_BaseRegisteredModel): """ Represents a model that has already been imported in the main process. """ interfaces: _ModelInfo model_cls: Type[nn.Module] @staticmethod def from_model_cls(model_cls: Type[nn.Module]): return _RegisteredModel( interfaces=_ModelInfo.from_model_cls(model_cls), model_cls=model_cls, ) def inspect_model_cls(self) -> _ModelInfo: return self.interfaces def load_model_cls(self) -> Type[nn.Module]: return self.model_cls @dataclass(frozen=True) class _LazyRegisteredModel(_BaseRegisteredModel): """ Represents a model that has not been imported in the main process. """ module_name: str class_name: str # Performed in another process to avoid initializing CUDA def inspect_model_cls(self) -> _ModelInfo: return _run_in_subprocess( lambda: _ModelInfo.from_model_cls(self.load_model_cls())) def load_model_cls(self) -> Type[nn.Module]: mod = importlib.import_module(self.module_name) return getattr(mod, self.class_name) @lru_cache(maxsize=128) def _try_load_model_cls( model_arch: str, model: _BaseRegisteredModel, ) -> Optional[Type[nn.Module]]: if is_hip(): if model_arch in _ROCM_UNSUPPORTED_MODELS: raise ValueError(f"Model architecture '{model_arch}' is not " "supported by ROCm for now.") if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch] logger.warning( "Model architecture '%s' is partially " "supported by ROCm: %s", model_arch, msg) try: return model.load_model_cls() except Exception: logger.exception("Error in loading model architecture '%s'", model_arch) return None @lru_cache(maxsize=128) def _try_inspect_model_cls( model_arch: str, model: _BaseRegisteredModel, ) -> Optional[_ModelInfo]: try: return model.inspect_model_cls() except Exception: logger.exception("Error in inspecting model architecture '%s'", model_arch) return None @dataclass class _ModelRegistry: # Keyed by model_arch models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict) def get_supported_archs(self) -> List[str]: return list(self.models.keys()) def register_model( self, model_arch: str, model_cls: Union[Type[nn.Module], str], ) -> None: """ Register an external model to be used in vLLM. :code:`model_cls` can be either: - A :class:`torch.nn.Module` class directly referencing the model. - A string in the format :code:`:` which can be used to lazily import the model. This is useful to avoid initializing CUDA when importing the model and thus the related error :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`. """ if model_arch in self.models: logger.warning( "Model architecture %s is already registered, and will be " "overwritten by the new model class %s.", model_arch, model_cls) if isinstance(model_cls, str): split_str = model_cls.split(":") if len(split_str) != 2: msg = "Expected a string in the format `:`" raise ValueError(msg) model = _LazyRegisteredModel(*split_str) else: model = _RegisteredModel.from_model_cls(model_cls) self.models[model_arch] = model def _raise_for_unsupported(self, architectures: List[str]): all_supported_archs = self.get_supported_archs() raise ValueError( f"Model architectures {architectures} are not supported for now. " f"Supported architectures: {all_supported_archs}") def _try_load_model_cls(self, model_arch: str) -> Optional[Type[nn.Module]]: if model_arch not in self.models: return None return _try_load_model_cls(model_arch, self.models[model_arch]) def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]: if model_arch not in self.models: return None return _try_inspect_model_cls(model_arch, self.models[model_arch]) def _normalize_archs( self, architectures: Union[str, List[str]], ) -> List[str]: if isinstance(architectures, str): architectures = [architectures] if not architectures: logger.warning("No model architectures are specified") return architectures def inspect_model_cls( self, architectures: Union[str, List[str]], ) -> _ModelInfo: architectures = self._normalize_archs(architectures) for arch in architectures: model_info = self._try_inspect_model_cls(arch) if model_info is not None: return model_info return self._raise_for_unsupported(architectures) def resolve_model_cls( self, architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]: architectures = self._normalize_archs(architectures) for arch in architectures: model_cls = self._try_load_model_cls(arch) if model_cls is not None: return (model_cls, arch) return self._raise_for_unsupported(architectures) def is_text_generation_model( self, architectures: Union[str, List[str]], ) -> bool: return self.inspect_model_cls(architectures).is_text_generation_model def is_embedding_model( self, architectures: Union[str, List[str]], ) -> bool: return self.inspect_model_cls(architectures).is_embedding_model def is_multimodal_model( self, architectures: Union[str, List[str]], ) -> bool: return self.inspect_model_cls(architectures).supports_multimodal def is_pp_supported_model( self, architectures: Union[str, List[str]], ) -> bool: return self.inspect_model_cls(architectures).supports_pp def model_has_inner_state(self, architectures: Union[str, List[str]]) -> bool: return self.inspect_model_cls(architectures).has_inner_state def is_attention_free_model(self, architectures: Union[str, List[str]]) -> bool: return self.inspect_model_cls(architectures).is_attention_free ModelRegistry = _ModelRegistry({ model_arch: _LazyRegisteredModel( module_name=f"vllm.model_executor.models.{mod_relname}", class_name=cls_name, ) for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items() }) _T = TypeVar("_T") def _run_in_subprocess(fn: Callable[[], _T]) -> _T: with tempfile.NamedTemporaryFile() as output_file: # `cloudpickle` allows pickling lambda functions directly input_bytes = cloudpickle.dumps((fn, output_file.name)) # cannot use `sys.executable __file__` here because the script # contains relative imports returned = subprocess.run( [sys.executable, "-m", "vllm.model_executor.models.registry"], input=input_bytes, capture_output=True) # check if the subprocess is successful try: returned.check_returncode() except Exception as e: # wrap raised exception to provide more information raise RuntimeError(f"Error raised in subprocess:\n" f"{returned.stderr.decode()}") from e with open(output_file.name, "rb") as f: return pickle.load(f) def _run() -> None: # Setup plugins from vllm.plugins import load_general_plugins load_general_plugins() fn, output_file = pickle.loads(sys.stdin.buffer.read()) result = fn() with open(output_file, "wb") as f: f.write(pickle.dumps(result)) if __name__ == "__main__": _run()