[doc] update the code to add models (#10603)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
youkaichao 2024-11-24 11:21:40 -08:00 committed by GitHub
parent c055747867
commit e4fbb14414
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -38,41 +38,70 @@ For instance, vLLM's `OPT model <https://github.com/vllm-project/vllm/blob/main/
When copying the model code, make sure to review and adhere to the code's copyright and licensing terms.
2. Rewrite the :code:`forward` methods
2. Make your code compatible with vLLM
--------------------------------------
Next, you need to rewrite the :meth:`~torch.nn.Module.forward` method of your model by following these steps:
To ensure compatibility with vLLM, your model must meet the following requirements:
1. Remove any unnecessary code, such as the code only used for training.
2. Change the input parameters:
Initialization Code
^^^^^^^^^^^^^^^^^^^
.. code-block:: diff
All vLLM modules within the model must include a ``prefix`` argument in their constructor. This ``prefix`` is typically the full name of the module in the model's state dictionary and is crucial for:
def forward(
self,
input_ids: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
+ positions: torch.Tensor,
+ kv_caches: List[torch.Tensor],
+ attn_metadata: AttentionMetadata,
+ ) -> Optional[SamplerOutput]:
* Runtime support: vLLM's attention operators are registered in a model's state by their full names. Each attention operator must have a unique prefix as its layer name to avoid conflicts.
* Non-uniform quantization support: A quantized checkpoint can selectively quantize certain layers while keeping others in full precision. By providing the ``prefix`` during initialization, vLLM can match the current layer's ``prefix`` with the quantization configuration to determine if the layer should be initialized in quantized mode.
1. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
2. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture.
The initialization code should look like this:
.. code-block:: python
from torch import nn
from vllm.config import VllmConfig
from vllm.attention import Attention
class MyAttention(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str):
super().__init__()
self.attn = Attention(prefix=f"{prefix}.attn")
class MyDecoderLayer(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str):
super().__init__()
self.self_attn = MyAttention(prefix=f"{prefix}.self_attn")
class MyModel(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str):
super().__init__()
self.layers = nn.ModuleList(
[MyDecoderLayer(vllm_config, prefix=f"{prefix}.layers.{i}") for i in range(vllm_config.model_config.hf_config.num_hidden_layers)]
)
class MyModelForCausalLM(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.model = MyModel(vllm_config, prefix=f"{prefix}.model")
Computation Code
^^^^^^^^^^^^^^^^
Rewrite the :meth:`~torch.nn.Module.forward` method of your model to remove any unnecessary code, such as training-specific code. Modify the input parameters to treat ``input_ids`` and ``positions`` as flattened tensors with a single batch size dimension, without a max-sequence length dimension.
.. code-block:: python
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
...
.. note::
Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings.
If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM.
For reference, check out the `LLAMA model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py>`__. vLLM already supports a large number of models. It is recommended to find a model similar to yours and adapt it to your model's architecture. Check out the `vLLM models <https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models>`__ directory for more examples.
3. (Optional) Implement tensor parallelism and quantization support
-------------------------------------------------------------------