Rename fallback model and refactor supported models section (#15829)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent b7b7676d67
commit a76f547e11
@@ -77,9 +77,9 @@ getting_started/v1_user_guide
 :caption: Models
 :maxdepth: 1

+models/supported_models
 models/generative_models
 models/pooling_models
-models/supported_models
 models/extensions/index
 :::

@@ -1,55 +1,28 @@
 (supported-models)=

-# List of Supported Models
+# Supported Models

-vLLM supports generative and pooling models across various tasks.
+vLLM supports [generative](generative-models) and [pooling](pooling-models) models across various tasks.
 If a model supports more than one task, you can set the task via the `--task` argument.

 For each task, we list the model architectures that have been implemented in vLLM.
 Alongside each architecture, we include some popular models that use it.

-## Loading a Model
+## Model Implementation

-### HuggingFace Hub
+### vLLM

-By default, vLLM loads models from [HuggingFace (HF) Hub](https://huggingface.co/models).
+If vLLM natively supports a model, its implementation can be found in <gh-file:vllm/model_executor/models>.

-To determine whether a given model is natively supported, you can check the `config.json` file inside the HF repository.
-If the `"architectures"` field contains a model architecture listed below, then it should be natively supported.
+These models are what we list in <project:#supported-text-models> and <project:#supported-mm-models>.

-Models do not _need_ to be natively supported to be used in vLLM.
-The <project:#transformers-fallback> enables you to run models directly using their Transformers implementation (or even remote code on the Hugging Face Model Hub!).
+(transformers-backend)=

-:::{tip}
-The easiest way to check if your model is really supported at runtime is to run the program below:
+### Transformers

-```python
-from vllm import LLM
+vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models are supported, and vision language model support is planned!

-# For generative models (task=generate) only
-llm = LLM(model=..., task="generate") # Name or path of your model
-output = llm.generate("Hello, my name is")
-print(output)

-# For pooling models (task={embed,classify,reward,score}) only
-llm = LLM(model=..., task="embed") # Name or path of your model
-output = llm.encode("Hello, my name is")
-print(output)
-```

-If vLLM successfully returns text (for generative models) or hidden states (for pooling models), it indicates that your model is supported.
-:::

-Otherwise, please refer to [Adding a New Model](#new-model) for instructions on how to implement your model in vLLM.
-Alternatively, you can [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) to request vLLM support.

-(transformers-fallback)=

-### Transformers fallback

-vLLM can fallback to model implementations that are available in Transformers. This does not work for all models for now, but most decoder language models are supported, and vision language model support is planned!

-To check if the backend is Transformers, you can simply do this:
+To check if the modeling backend is Transformers, you can simply do this:

 ```python
 from vllm import LLM
@@ -69,16 +42,15 @@ vLLM may not fully optimise the Transformers implementation so you may see degra

 #### Supported features

-The Transformers fallback explicitly supports the following features:
+The Transformers modeling backend explicitly supports the following features:

 - <project:#quantization-index> (except GGUF)
 - <project:#lora-adapter>
 - <project:#distributed-serving>

-#### Remote code
+#### Remote Code

-Earlier we mentioned that the Transformers fallback enables you to run remote code models directly in vLLM.
-If you are interested in this feature, this section is for you!
+If your model is neither supported natively by vLLM or Transformers, you can still run it in vLLM!

 Simply set `trust_remote_code=True` and vLLM will run any model on the Model Hub that is compatible with Transformers.
 Provided that the model writer implements their model in a compatible way, this means that you can run new models before they are officially supported in Transformers or vLLM!
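For illustration, the remote-code path described above can be exercised end to end. A minimal sketch, where `your-org/your-model` is a placeholder for a Hub repository whose custom modeling code you trust:

```python
from vllm import LLM

# Load a model whose modeling code ships inside its Hub repository.
# Only enable trust_remote_code for repositories you trust.
llm = LLM(model="your-org/your-model", task="generate", trust_remote_code=True)

# Print the class vLLM actually instantiated; TransformersForCausalLM means
# the Transformers modeling backend is being used.
llm.apply_model(lambda model: print(model.__class__))
```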
@@ -89,7 +61,7 @@ llm = LLM(model=..., task="generate", trust_remote_code=True) # Name or path of
 llm.apply_model(lambda model: print(model.__class__))
 ```

-To make your model compatible with the Transformers fallback, it needs:
+To make your model compatible with the Transformers backend, it needs:

 ```{code-block} python
 :caption: modeling_my_model.py
@@ -121,7 +93,9 @@ Here is what happens in the background:
 2. `MyModel` Python class is loaded from the `auto_map`, and we check that the model `_supports_attention_backend`.
 3. The `TransformersForCausalLM` backend is used. See <gh-file:vllm/model_executor/models/transformers.py>, which leverage `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTION`.

-To make your model compatible with tensor parallel, it needs:
+That's it!

+For your model to be compatible with vLLM's tensor parallel and/or pipeline parallel features, you must add `base_model_tp_plan` and/or `base_model_pp_plan` to your model's config class:

 ```{code-block} python
 :caption: configuration_my_model.py
@@ -130,20 +104,65 @@ from transformers import PretrainedConfig

 class MyConfig(PretrainedConfig):
     base_model_tp_plan = {
         "layers.*.self_attn.q_proj": "colwise",
-        ...
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
     }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
 ```

+- `base_model_tp_plan` is a `dict` that maps fully qualified layer name patterns to tensor parallel styles (currently only `"colwise"` and `"rowwise"` are supported).
+- `base_model_pp_plan` is a `dict` that maps direct child layer names to `tuple`s of `list`s of `str`s:
+  * You only need to do this for layers which are not present on all pipeline stages
+  * vLLM assumes that there will be only one `nn.ModuleList`, which is distributed across the pipeline stages
+  * The `list` in the first element of the `tuple` contains the names of the input arguments
+  * The `list` in the last element of the `tuple` contains the names of the variables the layer outputs to in your modeling code

+## Loading a Model

+### Hugging Face Hub

+By default, vLLM loads models from [Hugging Face (HF) Hub](https://huggingface.co/models).

+To determine whether a given model is natively supported, you can check the `config.json` file inside the HF repository.
+If the `"architectures"` field contains a model architecture listed below, then it should be natively supported.

+Models do not _need_ to be natively supported to be used in vLLM.
+The <project:#transformers-backend> enables you to run models directly using their Transformers implementation (or even remote code on the Hugging Face Model Hub!).

+:::{tip}
-`base_model_tp_plan` is a `dict` that maps fully qualified layer name patterns to tensor parallel styles (currently only `"colwise"` and `"rowwise"` are supported).
+The easiest way to check if your model is really supported at runtime is to run the program below:

+```python
+from vllm import LLM

+# For generative models (task=generate) only
+llm = LLM(model=..., task="generate") # Name or path of your model
+output = llm.generate("Hello, my name is")
+print(output)

+# For pooling models (task={embed,classify,reward,score}) only
+llm = LLM(model=..., task="embed") # Name or path of your model
+output = llm.encode("Hello, my name is")
+print(output)
+```

+If vLLM successfully returns text (for generative models) or hidden states (for pooling models), it indicates that your model is supported.
+:::

-That's it!
+Otherwise, please refer to [Adding a New Model](#new-model) for instructions on how to implement your model in vLLM.
+Alternatively, you can [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) to request vLLM support.

 ### ModelScope

-To use models from [ModelScope](https://www.modelscope.cn) instead of HuggingFace Hub, set an environment variable:
+To use models from [ModelScope](https://www.modelscope.cn) instead of Hugging Face Hub, set an environment variable:

 ```shell
 export VLLM_USE_MODELSCOPE=True
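The `config.json` inspection described under "Hugging Face Hub" above can also be scripted. A minimal sketch using `transformers.AutoConfig`; the model ID is a placeholder:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("your-org/your-model")

# Architectures declared in config.json; if one of them appears in the lists
# below, the model should be natively supported by vLLM.
print(getattr(config, "architectures", None))

# For remote-code models, auto_map points at the custom modeling classes the
# Transformers backend would load (loading such configs may additionally
# require trust_remote_code=True).
print(getattr(config, "auto_map", None))
```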
@@ -165,6 +184,8 @@ output = llm.encode("Hello, my name is")
 print(output)
 ```

+(supported-text-models)=

 ## List of Text-only Language Models

 ### Generative Models
@@ -1066,7 +1087,7 @@ At vLLM, we are committed to facilitating the integration and support of third-p
 2. **Best-Effort Consistency**: While we aim to maintain a level of consistency between the models implemented in vLLM and other frameworks like transformers, complete alignment is not always feasible. Factors like acceleration techniques and the use of low-precision computations can introduce discrepancies. Our commitment is to ensure that the implemented models are functional and produce sensible results.

 :::{tip}
-When comparing the output of `model.generate` from HuggingFace Transformers with the output of `llm.generate` from vLLM, note that the former reads the model's generation config file (i.e., [generation_config.json](https://github.com/huggingface/transformers/blob/19dabe96362803fb0a9ae7073d03533966598b17/src/transformers/generation/utils.py#L1945)) and applies the default parameters for generation, while the latter only uses the parameters passed to the function. Ensure all sampling parameters are identical when comparing outputs.
+When comparing the output of `model.generate` from Hugging Face Transformers with the output of `llm.generate` from vLLM, note that the former reads the model's generation config file (i.e., [generation_config.json](https://github.com/huggingface/transformers/blob/19dabe96362803fb0a9ae7073d03533966598b17/src/transformers/generation/utils.py#L1945)) and applies the default parameters for generation, while the latter only uses the parameters passed to the function. Ensure all sampling parameters are identical when comparing outputs.
 :::

 3. **Issue Resolution and Model Updates**: Users are encouraged to report any bugs or issues they encounter with third-party models. Proposed fixes should be submitted via PRs, with a clear explanation of the problem and the rationale behind the proposed solution. If a fix for one model impacts another, we rely on the community to highlight and address these cross-model dependencies. Note: for bugfix PRs, it is good etiquette to inform the original author to seek their feedback.
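One way to follow the tip above is to pin the generation settings explicitly on both sides before comparing. A sketch, assuming a small example checkpoint such as `facebook/opt-125m`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

model_id = "facebook/opt-125m"  # example model; swap in the one you are comparing
prompt = "Hello, my name is"

# Hugging Face side: disable sampling so generation_config.json defaults
# cannot silently change the comparison.
tokenizer = AutoTokenizer.from_pretrained(model_id)
hf_model = AutoModelForCausalLM.from_pretrained(model_id)
inputs = tokenizer(prompt, return_tensors="pt")
hf_ids = hf_model.generate(**inputs, do_sample=False, max_new_tokens=32)
print(tokenizer.decode(hf_ids[0], skip_special_tokens=True))

# vLLM side: pass the equivalent parameters explicitly (greedy, same length).
llm = LLM(model=model_id)
params = SamplingParams(temperature=0.0, max_tokens=32)
print(llm.generate(prompt, params)[0].outputs[0].text)
```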
@@ -346,7 +346,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
         trust_remote_code=True),
 }

-_FALLBACK_MODEL = {
+_TRANSFORMERS_MODELS = {
     "TransformersForCausalLM": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
 }
@@ -356,7 +356,7 @@ _EXAMPLE_MODELS = {
     **_CROSS_ENCODER_EXAMPLE_MODELS,
     **_MULTIMODAL_EXAMPLE_MODELS,
     **_SPECULATIVE_DECODING_EXAMPLE_MODELS,
-    **_FALLBACK_MODEL,
+    **_TRANSFORMERS_MODELS,
 }
@@ -39,8 +39,8 @@ def is_transformers_impl_compatible(
     return mod.is_backend_compatible()


-def resolve_transformers_fallback(model_config: ModelConfig,
-                                  architectures: list[str]):
+def resolve_transformers_arch(model_config: ModelConfig,
+                              architectures: list[str]):
     for i, arch in enumerate(architectures):
         if arch == "TransformersForCausalLM":
             continue
@@ -101,8 +101,7 @@ def get_model_architecture(
                                 for arch in architectures)
     if (not is_vllm_supported
             or model_config.model_impl == ModelImpl.TRANSFORMERS):
-        architectures = resolve_transformers_fallback(model_config,
-                                                      architectures)
+        architectures = resolve_transformers_arch(model_config, architectures)

     model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
     if model_config.task == "embed":
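The loader above switches to the Transformers implementation when no architecture is natively supported or when `model_impl` is set to `ModelImpl.TRANSFORMERS`. The same switch can be requested explicitly at construction time; a sketch, assuming your vLLM version exposes `model_impl` as an engine argument and using the example checkpoint from the test registry above:

```python
from vllm import LLM

# Explicitly select the Transformers modeling backend ("auto" and "vllm" are
# the other choices). Any Transformers-compatible checkpoint should work here.
llm = LLM(model="ArthurZ/Ilama-3.2-1B",
          model_impl="transformers",
          trust_remote_code=True)
llm.apply_model(lambda model: print(model.__class__))  # TransformersForCausalLM
```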
@@ -202,7 +202,7 @@ _SPECULATIVE_DECODING_MODELS = {
     "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
 }

-_FALLBACK_MODEL = {
+_TRANSFORMERS_MODELS = {
     "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
 }
 # yapf: enable
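Because `_TRANSFORMERS_MODELS` is merged into `_VLLM_MODELS` below, the generic architecture can be looked up like any native one. A sketch, assuming `ModelRegistry.get_supported_archs()` is available in your vLLM version:

```python
from vllm import ModelRegistry

# "TransformersForCausalLM" is the generic architecture name registered for
# the Transformers modeling backend alongside the native model entries.
print("TransformersForCausalLM" in ModelRegistry.get_supported_archs())
```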
@@ -213,7 +213,7 @@ _VLLM_MODELS = {
     **_CROSS_ENCODER_MODELS,
     **_MULTIMODAL_MODELS,
     **_SPECULATIVE_DECODING_MODELS,
-    **_FALLBACK_MODEL,
+    **_TRANSFORMERS_MODELS,
 }

 # This variable is used as the args for subprocess.run(). We
@@ -427,7 +427,7 @@ class _ModelRegistry:
         normalized_arch = list(
             filter(lambda model: model in self.models, architectures))

-        # make sure Transformers fallback are put at the last
+        # make sure Transformers backend is put at the last as a fallback
         if len(normalized_arch) != len(architectures):
             normalized_arch.append("TransformersForCausalLM")
         return normalized_arch
@@ -401,7 +401,7 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

-    # FIXME(Isotr0py): Don't use any weights mapper for Transformers fallback,
+    # FIXME(Isotr0py): Don't use any weights mapper for Transformers backend,
     # this makes thing complicated. We need to remove this mapper after refactor
     # `TransformersModel` in the future.
     @property