[Core] enable out-of-tree model register (#3871)
commit 95baec828f (parent e4be7d70bb)
@@ -34,7 +34,10 @@ steps:
   command: pytest -v -s engine tokenization test_sequence.py test_config.py
 
 - label: Entrypoints Test
-  command: pytest -v -s entrypoints
+  commands:
+  # these tests have to be separated, because each one will allocate all possible GPU memory
+  - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py
+  - pytest -v -s entrypoints/test_server_oot_registration.py
 
 - label: Examples Test
   working_dir: "/vllm-workspace/examples"
@@ -21,6 +21,8 @@ This document provides a high-level guide on integrating a `HuggingFace Transformers
 Start by forking our `GitHub`_ repository and then :ref:`build it from source <build_from_source>`.
 This gives you the ability to modify the codebase and test your model.
 
+.. tip::
+    If you don't want to fork the repository and modify vLLM's codebase, please refer to the "Out-of-Tree Model Integration" section below.
+
 1. Bring your model code
 ------------------------
@@ -94,3 +96,28 @@ This method should load the weights from the HuggingFace's checkpoint file and a
 ----------------------
 
 Finally, include your :code:`*ForCausalLM` class in `vllm/model_executor/models/__init__.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/__init__.py>`_ and register it to the :code:`_MODEL_REGISTRY` in `vllm/model_executor/model_loader.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/model_loader.py>`_.
+
+6. Out-of-Tree Model Integration
+--------------------------------
+
+We also provide a way to integrate a model without modifying the vLLM codebase. Steps 2, 3, and 4 are still required, but you can skip steps 1 and 5.
+
+Just add the following lines to your code:
+
+.. code-block:: python
+
+    from vllm import ModelRegistry
+    from your_code import YourModelForCausalLM
+    ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)
+
+If you are running the API server with `python -m vllm.entrypoints.openai.api_server args`, you can wrap the entrypoint with the following code:
+
+.. code-block:: python
+
+    from vllm import ModelRegistry
+    from your_code import YourModelForCausalLM
+    ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)
+    import runpy
+    runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
+
+Save the above code in a file and run it with `python your_file.py args`.
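For offline inference, the same registration can be combined with the `LLM` class. A minimal sketch, mirroring the pattern in tests/models/test_oot_registration.py below; `your_code`, `YourModelForCausalLM`, and the checkpoint path are placeholders, and the checkpoint's config is assumed to declare the `YourModelForCausalLM` architecture:

from vllm import LLM, ModelRegistry, SamplingParams
from your_code import YourModelForCausalLM  # placeholder out-of-tree model class

# Register the architecture before constructing the engine so that vLLM can
# resolve "YourModelForCausalLM" when it loads the checkpoint.
ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)

llm = LLM(model="path/to/your/checkpoint")  # placeholder path
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0))
print(outputs[0].outputs[0].text)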
tests/entrypoints/test_server_oot_registration.py (new file, 66 lines)
@@ -0,0 +1,66 @@
import multiprocessing
import sys
import time

import torch
from openai import OpenAI, OpenAIError

from vllm import ModelRegistry
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.utils import get_open_port


class MyOPTForCausalLM(OPTForCausalLM):

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        # this dummy model always predicts the first token
        logits = super().compute_logits(hidden_states, sampling_metadata)
        logits.zero_()
        logits[:, 0] += 1.0
        return logits


def server_function(port):
    # register our dummy model
    ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
    sys.argv = ["placeholder.py"] + \
        ("--model facebook/opt-125m --dtype"
         f" float32 --api-key token-abc123 --port {port}").split()
    import runpy
    runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')


def test_oot_registration_for_api_server():
    port = get_open_port()
    server = multiprocessing.Process(target=server_function, args=(port, ))
    server.start()
    client = OpenAI(
        base_url=f"http://localhost:{port}/v1",
        api_key="token-abc123",
    )
    while True:
        try:
            completion = client.chat.completions.create(
                model="facebook/opt-125m",
                messages=[{
                    "role": "system",
                    "content": "You are a helpful assistant."
                }, {
                    "role": "user",
                    "content": "Hello!"
                }],
                temperature=0,
            )
            break
        except OpenAIError as e:
            if "Connection error" in str(e):
                time.sleep(3)
            else:
                raise e
    server.kill()
    generated_text = completion.choices[0].message.content
    # make sure only the first token is generated
    rest = generated_text.replace("<s>", "")
    assert rest == ""
tests/models/test_oot_registration.py (new file, 32 lines)
@@ -0,0 +1,32 @@
import torch

from vllm import LLM, ModelRegistry, SamplingParams
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata


class MyOPTForCausalLM(OPTForCausalLM):

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        # this dummy model always predicts the first token
        logits = super().compute_logits(hidden_states, sampling_metadata)
        logits.zero_()
        logits[:, 0] += 1.0
        return logits


def test_oot_registration():
    # register our dummy model
    ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
    prompts = ["Hello, my name is", "The text does not matter"]
    sampling_params = SamplingParams(temperature=0)
    llm = LLM(model="facebook/opt-125m")
    first_token = llm.get_tokenizer().decode(0)
    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        generated_text = output.outputs[0].text
        # make sure only the first token is generated
        rest = generated_text.replace(first_token, "")
        assert rest == ""
@@ -5,6 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.llm_engine import LLMEngine
 from vllm.engine.ray_utils import initialize_ray_cluster
 from vllm.entrypoints.llm import LLM
+from vllm.model_executor.models import ModelRegistry
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import SamplingParams
 
@@ -12,6 +13,7 @@ __version__ = "0.4.0.post1"
 
 __all__ = [
     "LLM",
+    "ModelRegistry",
     "SamplingParams",
     "RequestOutput",
     "CompletionOutput",
@@ -1,5 +1,5 @@
 import importlib
-from typing import List, Optional, Type
+from typing import Dict, List, Optional, Type
 
 import torch.nn as nn
 
@@ -55,6 +55,10 @@ _MODELS = {
     "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
 }
 
+# Architecture -> type.
+# out of tree models
+_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
+
 # Models not supported by ROCm.
 _ROCM_UNSUPPORTED_MODELS = []
 
@@ -74,6 +78,8 @@ class ModelRegistry:
 
     @staticmethod
    def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
+        if model_arch in _OOT_MODELS:
+            return _OOT_MODELS[model_arch]
         if model_arch not in _MODELS:
             return None
         if is_hip():
@@ -95,6 +101,16 @@ class ModelRegistry:
     def get_supported_archs() -> List[str]:
         return list(_MODELS.keys())
 
+    @staticmethod
+    def register_model(model_arch: str, model_cls: Type[nn.Module]):
+        if model_arch in _MODELS:
+            logger.warning(
+                f"Model architecture {model_arch} is already registered, "
+                "and will be overwritten by the new model "
+                f"class {model_cls.__name__}.")
+        global _OOT_MODELS
+        _OOT_MODELS[model_arch] = model_cls
+
 
 __all__ = [
     "ModelRegistry",
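A rough sketch of the lookup order this change introduces (the `FooForCausalLM` class below is a hypothetical stand-in, not part of the commit): once an architecture is registered out of tree, `load_model_cls` returns it before consulting the built-in `_MODELS` table, so a registration can also override a built-in architecture of the same name.

import torch.nn as nn

from vllm import ModelRegistry


class FooForCausalLM(nn.Module):
    """Hypothetical out-of-tree model class, used only for illustration."""
    pass


# register_model() stores the class in _OOT_MODELS (and warns if the
# architecture name shadows a built-in entry in _MODELS).
ModelRegistry.register_model("FooForCausalLM", FooForCausalLM)

# load_model_cls() checks _OOT_MODELS first, so the registered class is returned.
assert ModelRegistry.load_model_cls("FooForCausalLM") is FooForCausalLM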