[doc] update the code to add models (#10603)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

This commit is contained in:
parent c055747867
commit e4fbb14414

@@ -38,41 +38,70 @@ For instance, vLLM's `OPT model <https://github.com/vllm-project/vllm/blob/main/
When copying the model code, make sure to review and adhere to the code's copyright and licensing terms.

2. Make your code compatible with vLLM
--------------------------------------

To ensure compatibility with vLLM, your model must meet the following requirements:

Initialization Code
^^^^^^^^^^^^^^^^^^^

All vLLM modules within the model must include a ``prefix`` argument in their constructor. This ``prefix`` is typically the full name of the module in the model's state dictionary and is crucial for:

* Runtime support: vLLM's attention operators are registered in a model's state by their full names. Each attention operator must have a unique prefix as its layer name to avoid conflicts.
* Non-uniform quantization support: A quantized checkpoint can selectively quantize certain layers while keeping others in full precision. By providing the ``prefix`` during initialization, vLLM can match the current layer's ``prefix`` against the quantization configuration to determine whether the layer should be initialized in quantized mode, as sketched right after this list.
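
The following is a minimal, hypothetical sketch of that matching step; ``should_quantize`` and the layer names are illustrative stand-ins, not vLLM's actual API:

.. code-block:: python

    from typing import Set

    def should_quantize(prefix: str, quantized_layers: Set[str]) -> bool:
        """Hypothetical helper: consult the quantization config to decide
        whether the module at ``prefix`` is built in quantized mode."""
        return prefix in quantized_layers

    # A checkpoint that quantizes only the first decoder layer's attention:
    quantized_layers = {"model.layers.0.self_attn.attn"}
    assert should_quantize("model.layers.0.self_attn.attn", quantized_layers)
    assert not should_quantize("model.layers.1.self_attn.attn", quantized_layers)
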
The initialization code should look like this:

.. code-block:: python

    from torch import nn
    from vllm.config import VllmConfig
    from vllm.attention import Attention

    class MyAttention(nn.Module):
        def __init__(self, vllm_config: VllmConfig, prefix: str):
            super().__init__()
            # The attention op registers itself under f"{prefix}.attn".
            self.attn = Attention(prefix=f"{prefix}.attn")

    class MyDecoderLayer(nn.Module):
        def __init__(self, vllm_config: VllmConfig, prefix: str):
            super().__init__()
            self.self_attn = MyAttention(vllm_config, prefix=f"{prefix}.self_attn")

    class MyModel(nn.Module):
        def __init__(self, vllm_config: VllmConfig, prefix: str):
            super().__init__()
            # Each decoder layer extends its parent's prefix with ".layers.{i}".
            self.layers = nn.ModuleList(
                [MyDecoderLayer(vllm_config, prefix=f"{prefix}.layers.{i}")
                 for i in range(vllm_config.model_config.hf_config.num_hidden_layers)]
            )

    class MyModelForCausalLM(nn.Module):
        def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
            super().__init__()
            self.model = MyModel(vllm_config, prefix=f"{prefix}.model")
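
With this wiring, every attention module receives a unique dotted name (for example one ending in ``.layers.0.self_attn.attn``), which is exactly what the runtime registration and the quantization matching described above rely on.
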
Computation Code
^^^^^^^^^^^^^^^^

Rewrite the :meth:`~torch.nn.Module.forward` method of your model to remove any unnecessary code, such as training-specific code. Modify the input parameters to treat ``input_ids`` and ``positions`` as flattened tensors with a single batch-size dimension and no max-sequence-length dimension.

.. code-block:: python

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        ...
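
To make the flattened layout concrete, here is a minimal sketch (plain PyTorch, not vLLM API; the token IDs are made up) of two prompts packed into one batch dimension:

.. code-block:: python

    import torch

    # Two prompts of lengths 3 and 2, concatenated back to back
    # with no padding and no max-sequence-length dimension.
    seq_a = torch.tensor([101, 2023, 2003])  # length 3
    seq_b = torch.tensor([101, 7592])        # length 2

    input_ids = torch.cat([seq_a, seq_b])      # shape [5]
    positions = torch.tensor([0, 1, 2, 0, 1])  # per-token positions, shape [5]

    assert input_ids.shape == positions.shape == torch.Size([5])
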
.. note::
    Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings.
    If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM.

For reference, check out the `LLAMA model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py>`__. vLLM already supports a large number of models. It is recommended to find a model similar to yours and adapt it to your model's architecture. Check out the `vLLM models <https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models>`__ directory for more examples.

3. (Optional) Implement tensor parallelism and quantization support
-------------------------------------------------------------------