[Core] Support load and unload LoRA in api server (#6566)
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
parent: 2febcf2777
commit: db3bf7c991
@@ -11,6 +11,5 @@ pydantic >= 2.8
 torch
 py-cpuinfo
 transformers
-openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
 mistral_common >= 1.3.4
 openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
@@ -107,3 +107,55 @@ The following is an example request
     "max_tokens": 7,
     "temperature": 0
 }' | jq
+
+
+Dynamically serving LoRA Adapters
+---------------------------------
+
+In addition to serving LoRA adapters at server startup, the vLLM server now supports dynamically loading and unloading
+LoRA adapters at runtime through dedicated API endpoints. This feature can be particularly useful when the flexibility
+to change models on-the-fly is needed.
+
+Note: Enabling this feature in production environments is risky, as it allows users to take part in model adapter management.
+
+To enable dynamic LoRA loading and unloading, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING`
+is set to `True`. When this option is enabled, the API server will log a warning to indicate that dynamic loading is active.
+
+.. code-block:: bash
+
+    export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True
+
+
+Loading a LoRA Adapter:
+
+To dynamically load a LoRA adapter, send a POST request to the `/v1/load_lora_adapter` endpoint with the necessary
+details of the adapter to be loaded. The request payload should include the name and path of the LoRA adapter.
+
+Example request to load a LoRA adapter:
+
+.. code-block:: bash
+
+    curl -X POST http://localhost:8000/v1/load_lora_adapter \
+    -H "Content-Type: application/json" \
+    -d '{
+        "lora_name": "sql_adapter",
+        "lora_path": "/path/to/sql-lora-adapter"
+    }'
+
+Upon a successful request, the API will respond with a 200 OK status code. If an error occurs, such as if the adapter
+cannot be found or loaded, an appropriate error message will be returned.
+
+Unloading a LoRA Adapter:
+
+To unload a LoRA adapter that has been previously loaded, send a POST request to the `/v1/unload_lora_adapter` endpoint
+with the name or ID of the adapter to be unloaded.
+
+Example request to unload a LoRA adapter:
+
+.. code-block:: bash
+
+    curl -X POST http://localhost:8000/v1/unload_lora_adapter \
+    -H "Content-Type: application/json" \
+    -d '{
+        "lora_name": "sql_adapter"
+    }'
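For completeness, the two documented endpoints can also be driven from Python. The following is a minimal sketch using the `requests` library (not part of this commit); it assumes a vLLM server on `localhost:8000` started with `VLLM_ALLOW_RUNTIME_LORA_UPDATING=True`, and the adapter path is illustrative.

.. code-block:: python

    import requests

    BASE_URL = "http://localhost:8000"

    # Load an adapter; the server replies 200 on success, or a JSON
    # ErrorResponse body with a 400 status code on failure.
    resp = requests.post(f"{BASE_URL}/v1/load_lora_adapter",
                         json={
                             "lora_name": "sql_adapter",
                             "lora_path": "/path/to/sql-lora-adapter",
                         })
    if resp.status_code != 200:
        print("load failed:", resp.json())

    # Unload the adapter again by name.
    resp = requests.post(f"{BASE_URL}/v1/unload_lora_adapter",
                         json={"lora_name": "sql_adapter"})
    print("unload status:", resp.status_code)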
@@ -50,7 +50,7 @@ def zephyr_lora_files():
 @pytest.mark.skip_global_cleanup
 def test_multiple_lora_requests(llm: LLM, zephyr_lora_files):
     lora_request = [
-        LoRARequest(LORA_NAME, idx + 1, zephyr_lora_files)
+        LoRARequest(LORA_NAME + str(idx), idx + 1, zephyr_lora_files)
         for idx in range(len(PROMPTS))
     ]
     # Multiple SamplingParams should be matched with each prompt
|
|||||||
107
tests/entrypoints/openai/test_serving_engine.py
Normal file
107
tests/entrypoints/openai/test_serving_engine.py
Normal file
@ -0,0 +1,107 @@
+from http import HTTPStatus
+from unittest.mock import MagicMock
+
+import pytest
+
+from vllm.config import ModelConfig
+from vllm.engine.protocol import AsyncEngineClient
+from vllm.entrypoints.openai.protocol import (ErrorResponse,
+                                              LoadLoraAdapterRequest,
+                                              UnloadLoraAdapterRequest)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+
+MODEL_NAME = "meta-llama/Llama-2-7b"
+LORA_LOADING_SUCCESS_MESSAGE = (
+    "Success: LoRA adapter '{lora_name}' added successfully.")
+LORA_UNLOADING_SUCCESS_MESSAGE = (
+    "Success: LoRA adapter '{lora_name}' removed successfully.")
+
+
+async def _async_serving_engine_init():
+    mock_engine_client = MagicMock(spec=AsyncEngineClient)
+    mock_model_config = MagicMock(spec=ModelConfig)
+    # Set the max_model_len attribute to avoid missing attribute
+    mock_model_config.max_model_len = 2048
+
+    serving_engine = OpenAIServing(mock_engine_client,
+                                   mock_model_config,
+                                   served_model_names=[MODEL_NAME],
+                                   lora_modules=None,
+                                   prompt_adapters=None,
+                                   request_logger=None)
+    return serving_engine
+
+
+@pytest.mark.asyncio
+async def test_load_lora_adapter_success():
+    serving_engine = await _async_serving_engine_init()
+    request = LoadLoraAdapterRequest(lora_name="adapter",
+                                     lora_path="/path/to/adapter2")
+    response = await serving_engine.load_lora_adapter(request)
+    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
+    assert len(serving_engine.lora_requests) == 1
+    assert serving_engine.lora_requests[0].lora_name == "adapter"
+
+
+@pytest.mark.asyncio
+async def test_load_lora_adapter_missing_fields():
+    serving_engine = await _async_serving_engine_init()
+    request = LoadLoraAdapterRequest(lora_name="", lora_path="")
+    response = await serving_engine.load_lora_adapter(request)
+    assert isinstance(response, ErrorResponse)
+    assert response.type == "InvalidUserInput"
+    assert response.code == HTTPStatus.BAD_REQUEST
+
+
+@pytest.mark.asyncio
+async def test_load_lora_adapter_duplicate():
+    serving_engine = await _async_serving_engine_init()
+    request = LoadLoraAdapterRequest(lora_name="adapter1",
+                                     lora_path="/path/to/adapter1")
+    response = await serving_engine.load_lora_adapter(request)
+    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
+        lora_name='adapter1')
+    assert len(serving_engine.lora_requests) == 1
+
+    request = LoadLoraAdapterRequest(lora_name="adapter1",
+                                     lora_path="/path/to/adapter1")
+    response = await serving_engine.load_lora_adapter(request)
+    assert isinstance(response, ErrorResponse)
+    assert response.type == "InvalidUserInput"
+    assert response.code == HTTPStatus.BAD_REQUEST
+    assert len(serving_engine.lora_requests) == 1
+
+
+@pytest.mark.asyncio
+async def test_unload_lora_adapter_success():
+    serving_engine = await _async_serving_engine_init()
+    request = LoadLoraAdapterRequest(lora_name="adapter1",
+                                     lora_path="/path/to/adapter1")
+    response = await serving_engine.load_lora_adapter(request)
+    assert len(serving_engine.lora_requests) == 1
+
+    request = UnloadLoraAdapterRequest(lora_name="adapter1")
+    response = await serving_engine.unload_lora_adapter(request)
+    assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
+        lora_name='adapter1')
+    assert len(serving_engine.lora_requests) == 0
+
+
+@pytest.mark.asyncio
+async def test_unload_lora_adapter_missing_fields():
+    serving_engine = await _async_serving_engine_init()
+    request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None)
+    response = await serving_engine.unload_lora_adapter(request)
+    assert isinstance(response, ErrorResponse)
+    assert response.type == "InvalidUserInput"
+    assert response.code == HTTPStatus.BAD_REQUEST
+
+
+@pytest.mark.asyncio
+async def test_unload_lora_adapter_not_found():
+    serving_engine = await _async_serving_engine_init()
+    request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter")
+    response = await serving_engine.unload_lora_adapter(request)
+    assert isinstance(response, ErrorResponse)
+    assert response.type == "InvalidUserInput"
+    assert response.code == HTTPStatus.BAD_REQUEST
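To run just this new test module, assuming a development checkout with `pytest` and `pytest-asyncio` available:

.. code-block:: bash

    pytest tests/entrypoints/openai/test_serving_engine.py -v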
@@ -35,11 +35,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DetokenizeResponse,
                                               EmbeddingRequest,
                                               EmbeddingResponse, ErrorResponse,
+                                              LoadLoraAdapterRequest,
                                               TokenizeRequest,
-                                              TokenizeResponse)
-# yapf: enable
+                                              TokenizeResponse,
+                                              UnloadLoraAdapterRequest)
 from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
 from vllm.entrypoints.openai.rpc.server import run_rpc_server
+# yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
@@ -343,6 +345,40 @@ if envs.VLLM_TORCH_PROFILER_DIR:
         return Response(status_code=200)


+if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
+    logger.warning(
+        "Lora dynamic loading & unloading is enabled in the API server. "
+        "This should ONLY be used for local development!")
+
+    @router.post("/v1/load_lora_adapter")
+    async def load_lora_adapter(request: LoadLoraAdapterRequest):
+        response = await openai_serving_chat.load_lora_adapter(request)
+        if isinstance(response, ErrorResponse):
+            return JSONResponse(content=response.model_dump(),
+                                status_code=response.code)
+
+        response = await openai_serving_completion.load_lora_adapter(request)
+        if isinstance(response, ErrorResponse):
+            return JSONResponse(content=response.model_dump(),
+                                status_code=response.code)
+
+        return Response(status_code=200, content=response)
+
+    @router.post("/v1/unload_lora_adapter")
+    async def unload_lora_adapter(request: UnloadLoraAdapterRequest):
+        response = await openai_serving_chat.unload_lora_adapter(request)
+        if isinstance(response, ErrorResponse):
+            return JSONResponse(content=response.model_dump(),
+                                status_code=response.code)
+
+        response = await openai_serving_completion.unload_lora_adapter(request)
+        if isinstance(response, ErrorResponse):
+            return JSONResponse(content=response.model_dump(),
+                                status_code=response.code)
+
+        return Response(status_code=200, content=response)
+
+
 def build_app(args: Namespace) -> FastAPI:
     app = FastAPI(lifespan=lifespan)
     app.include_router(router)
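Because both handlers forward the `ErrorResponse` body together with its status code, failures are easy to inspect from the shell. A sketch of the duplicate-load error path, assuming a local server on which `sql_adapter` has already been loaded:

.. code-block:: bash

    # A second load with the same name returns HTTP 400 with an
    # "InvalidUserInput" ErrorResponse body.
    curl -s -w '\nHTTP %{http_code}\n' -X POST \
        http://localhost:8000/v1/load_lora_adapter \
        -H "Content-Type: application/json" \
        -d '{"lora_name": "sql_adapter", "lora_path": "/path/to/sql-lora-adapter"}'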
@@ -878,3 +878,13 @@ class DetokenizeRequest(OpenAIBaseModel):

 class DetokenizeResponse(OpenAIBaseModel):
     prompt: str
+
+
+class LoadLoraAdapterRequest(BaseModel):
+    lora_name: str
+    lora_path: str
+
+
+class UnloadLoraAdapterRequest(BaseModel):
+    lora_name: str
+    lora_int_id: Optional[int] = Field(default=None)
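These are plain Pydantic models, so the payload shapes documented above can be checked directly in Python; a small sketch (the import path matches the test file in this commit):

.. code-block:: python

    from vllm.entrypoints.openai.protocol import (LoadLoraAdapterRequest,
                                                  UnloadLoraAdapterRequest)

    load_req = LoadLoraAdapterRequest(lora_name="sql_adapter",
                                      lora_path="/path/to/sql-lora-adapter")
    print(load_req.model_dump())
    # {'lora_name': 'sql_adapter', 'lora_path': '/path/to/sql-lora-adapter'}

    # lora_int_id is optional and defaults to None.
    unload_req = UnloadLoraAdapterRequest(lora_name="sql_adapter")
    print(unload_req.lora_int_id)  # None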
@@ -16,11 +16,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionRequest,
                                               DetokenizeRequest,
                                               EmbeddingRequest, ErrorResponse,
+                                              LoadLoraAdapterRequest,
                                               ModelCard, ModelList,
                                               ModelPermission,
                                               TokenizeChatRequest,
                                               TokenizeCompletionRequest,
-                                              TokenizeRequest)
+                                              TokenizeRequest,
+                                              UnloadLoraAdapterRequest)
 # yapf: enable
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.logger import init_logger
@@ -32,6 +34,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import LogitsProcessor, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import AtomicCounter

 logger = init_logger(__name__)

@@ -78,6 +81,7 @@ class OpenAIServing:

         self.served_model_names = served_model_names

+        self.lora_id_counter = AtomicCounter(0)
         self.lora_requests = []
         if lora_modules is not None:
             self.lora_requests = [
@ -403,3 +407,76 @@ class OpenAIServing:
|
|||||||
if logprob.decoded_token is not None:
|
if logprob.decoded_token is not None:
|
||||||
return logprob.decoded_token
|
return logprob.decoded_token
|
||||||
return tokenizer.decode(token_id)
|
return tokenizer.decode(token_id)
|
||||||
|
|
||||||
|
async def _check_load_lora_adapter_request(
|
||||||
|
self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
|
||||||
|
# Check if both 'lora_name' and 'lora_path' are provided
|
||||||
|
if not request.lora_name or not request.lora_path:
|
||||||
|
return self.create_error_response(
|
||||||
|
message="Both 'lora_name' and 'lora_path' must be provided.",
|
||||||
|
err_type="InvalidUserInput",
|
||||||
|
status_code=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
# Check if the lora adapter with the given name already exists
|
||||||
|
if any(lora_request.lora_name == request.lora_name
|
||||||
|
for lora_request in self.lora_requests):
|
||||||
|
return self.create_error_response(
|
||||||
|
message=
|
||||||
|
f"The lora adapter '{request.lora_name}' has already been"
|
||||||
|
"loaded.",
|
||||||
|
err_type="InvalidUserInput",
|
||||||
|
status_code=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _check_unload_lora_adapter_request(
|
||||||
|
self,
|
||||||
|
request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
|
||||||
|
# Check if either 'lora_name' or 'lora_int_id' is provided
|
||||||
|
if not request.lora_name and not request.lora_int_id:
|
||||||
|
return self.create_error_response(
|
||||||
|
message=
|
||||||
|
"either 'lora_name' and 'lora_int_id' needs to be provided.",
|
||||||
|
err_type="InvalidUserInput",
|
||||||
|
status_code=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
# Check if the lora adapter with the given name exists
|
||||||
|
if not any(lora_request.lora_name == request.lora_name
|
||||||
|
for lora_request in self.lora_requests):
|
||||||
|
return self.create_error_response(
|
||||||
|
message=
|
||||||
|
f"The lora adapter '{request.lora_name}' cannot be found.",
|
||||||
|
err_type="InvalidUserInput",
|
||||||
|
status_code=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def load_lora_adapter(
|
||||||
|
self,
|
||||||
|
request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
|
||||||
|
error_check_ret = await self._check_load_lora_adapter_request(request)
|
||||||
|
if error_check_ret is not None:
|
||||||
|
return error_check_ret
|
||||||
|
|
||||||
|
lora_name, lora_path = request.lora_name, request.lora_path
|
||||||
|
unique_id = self.lora_id_counter.inc(1)
|
||||||
|
self.lora_requests.append(
|
||||||
|
LoRARequest(lora_name=lora_name,
|
||||||
|
lora_int_id=unique_id,
|
||||||
|
lora_path=lora_path))
|
||||||
|
return f"Success: LoRA adapter '{lora_name}' added successfully."
|
||||||
|
|
||||||
|
async def unload_lora_adapter(
|
||||||
|
self,
|
||||||
|
request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
|
||||||
|
error_check_ret = await self._check_unload_lora_adapter_request(request
|
||||||
|
)
|
||||||
|
if error_check_ret is not None:
|
||||||
|
return error_check_ret
|
||||||
|
|
||||||
|
lora_name = request.lora_name
|
||||||
|
self.lora_requests = [
|
||||||
|
lora_request for lora_request in self.lora_requests
|
||||||
|
if lora_request.lora_name != lora_name
|
||||||
|
]
|
||||||
|
return f"Success: LoRA adapter '{lora_name}' removed successfully."
|
||||||
@@ -61,6 +61,7 @@ if TYPE_CHECKING:
     VLLM_ALLOW_ENGINE_USE_RAY: bool = False
     VLLM_PLUGINS: Optional[List[str]] = None
     VLLM_TORCH_PROFILER_DIR: Optional[str] = None
+    VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False


 def get_default_cache_root():
@@ -409,6 +410,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     # If set, vLLM will use Triton implementations of AWQ.
     "VLLM_USE_TRITON_AWQ":
     lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
+
+    # If set, allow loading or unloading lora adapters at runtime.
+    "VLLM_ALLOW_RUNTIME_LORA_UPDATING":
+    lambda:
+    (os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in
+     ("1", "true")),
 }

 # end-env-vars-definition
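Note that the lambda recognizes only the strings "1" and "true" (after stripping and lower-casing) as enabled; anything else, including "yes" or "on", is falsy. A stand-alone sketch of the same rule:

.. code-block:: python

    # Restatement of the truthiness rule used above, outside of vllm.envs.
    def _is_enabled(raw: str) -> bool:
        return raw.strip().lower() in ("1", "true")

    assert _is_enabled("True")
    assert _is_enabled(" 1 ")
    assert not _is_enabled("yes")  # not a recognized truthy value
    assert not _is_enabled("0")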
@@ -28,7 +28,6 @@ class LoRARequest(
     lora_path: str = ""
     lora_local_path: Optional[str] = msgspec.field(default=None)
     long_lora_max_len: Optional[int] = None
-    __hash__ = AdapterRequest.__hash__

     def __post_init__(self):
         if 'lora_local_path' in self.__struct_fields__:
@@ -75,3 +74,21 @@ class LoRARequest(
                 DeprecationWarning,
                 stacklevel=2)
         self.lora_path = value
+
+    def __eq__(self, value: object) -> bool:
+        """
+        Overrides the equality method to compare LoRARequest
+        instances based on lora_name. This allows for identification
+        and comparison of lora adapters across engines.
+        """
+        return isinstance(
+            value, self.__class__) and self.lora_name == value.lora_name
+
+    def __hash__(self) -> int:
+        """
+        Overrides the hash method to hash LoRARequest instances
+        based on lora_name. This ensures that LoRARequest instances
+        can be used in hash-based collections such as sets and dictionaries,
+        identified by their names across engines.
+        """
+        return hash(self.lora_name)
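With identity keyed on `lora_name` alone, two requests that differ only in `lora_int_id` or `lora_path` compare equal and collapse in hash-based collections, which is also why the test change above now gives each request a distinct name. A small sketch:

.. code-block:: python

    from vllm.lora.request import LoRARequest

    a = LoRARequest(lora_name="sql_adapter", lora_int_id=1,
                    lora_path="/path/a")
    b = LoRARequest(lora_name="sql_adapter", lora_int_id=2,
                    lora_path="/path/b")

    assert a == b            # equal despite different ids and paths
    assert len({a, b}) == 1  # deduplicated in a set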
@@ -1224,3 +1224,28 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
 def supports_dynamo() -> bool:
     base_torch_version = Version(Version(torch.__version__).base_version)
     return base_torch_version >= Version("2.4.0")
+
+
+class AtomicCounter:
+    """An atomic, thread-safe counter"""
+
+    def __init__(self, initial=0):
+        """Initialize a new atomic counter to given initial value"""
+        self._value = initial
+        self._lock = threading.Lock()
+
+    def inc(self, num=1):
+        """Atomically increment the counter by num and return the new value"""
+        with self._lock:
+            self._value += num
+            return self._value
+
+    def dec(self, num=1):
+        """Atomically decrement the counter by num and return the new value"""
+        with self._lock:
+            self._value -= num
+            return self._value
+
+    @property
+    def value(self):
+        return self._value
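A quick way to convince yourself the counter holds up under contention, using only the standard library:

.. code-block:: python

    import threading

    from vllm.utils import AtomicCounter

    counter = AtomicCounter(0)

    def worker():
        for _ in range(10_000):
            counter.inc(1)

    threads = [threading.Thread(target=worker) for _ in range(8)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    # The lock prevents lost updates, so no increments are dropped.
    assert counter.value == 8 * 10_000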