[Misc] Improve LoRA spelling (#13831)

Jee Jee Li, 2025-02-26 15:43:01 +08:00, committed by GitHub
parent e206b54331
commit 5157338ed9
25 changed files with 80 additions and 80 deletions

View File

@ -89,7 +89,7 @@ def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int,
sort_by_lora_id: bool,
device: str) -> torch.Tensor:
"""
All prompts are mapped to a Lora ID in range [0, num_active_loras).
All prompts are mapped to a LoRA ID in range [0, num_active_loras).
where 0 refers to first lora, 1 refers to second lora and so on.
"""
assert num_active_loras > 0
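
The docstring above states the mapping contract; below is a minimal sketch of one way to satisfy it. This is an assumption about the benchmark helper's behavior, not the vLLM implementation.

```python
import torch

def make_prompt_lora_mapping_sketch(num_prompts: int, num_active_loras: int,
                                    sort_by_lora_id: bool,
                                    device: str = "cpu") -> torch.Tensor:
    assert num_active_loras > 0
    # Round-robin: prompt i gets LoRA ID (i % num_active_loras).
    mapping = torch.arange(num_prompts, device=device) % num_active_loras
    if sort_by_lora_id:
        # Group prompts that share a LoRA ID so their requests can be batched.
        mapping, _ = torch.sort(mapping)
    return mapping

print(make_prompt_lora_mapping_sketch(6, 3, sort_by_lora_id=True))
# tensor([0, 0, 1, 1, 2, 2])
```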

View File

@ -170,7 +170,7 @@ Now, you can specify a base_model_name alongside the name and path using JSON fo
To provide the backward compatibility support, you can still use the old key-value format (name=path), but the `base_model_name` will remain unspecified in that case.
## Lora model lineage in model card
## LoRA model lineage in model card
The new format of `--lora-modules` is mainly to support the display of parent model information in the model card. Here's an explanation of how your current response supports this:
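
For reference, a hedged sketch of the two forms the flag accepts; the adapter name, path, and base model below are placeholders, not values from this commit.

```python
import json

# Old key=value form: the parent model cannot be expressed.
#   --lora-modules sql-lora=/path/to/sql-lora-adapter
# New JSON form: base_model_name is recorded for the model card.
lora_module = {
    "name": "sql-lora",                              # hypothetical adapter name
    "path": "/path/to/sql-lora-adapter",             # hypothetical local path
    "base_model_name": "meta-llama/Llama-2-7b-hf",   # hypothetical parent model
}
print(f"--lora-modules '{json.dumps(lora_module)}'")
```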

View File

@ -491,7 +491,7 @@ def test_prefill_schedule_max_lora():
lora_path="abc"))
scheduler.add_seq_group(seq_group)
# Add two more requests to verify lora is prioritized.
# 0: Lora, 1: Lora, 2: regular, 3: regular
# 0: LoRA, 1: LoRA, 2: regular, 3: regular
# In the first iteration, index 0, 2 is scheduled.
# If a request is not scheduled because it hits max lora, it is
# prioritized. Verify that.

View File

@ -26,7 +26,7 @@ def serve_parser():
return make_arg_parser(parser)
### Tests for Lora module parsing
### Tests for LoRA module parsing
def test_valid_key_value_format(serve_parser):
# Test old format: name=path
args = serve_parser.parse_args([

View File

@ -8,8 +8,8 @@ import pytest
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse,
LoadLoraAdapterRequest,
UnloadLoraAdapterRequest)
LoadLoRAAdapterRequest,
UnloadLoRAAdapterRequest)
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.lora.request import LoRARequest
@ -51,7 +51,7 @@ async def test_serving_model_name():
@pytest.mark.asyncio
async def test_load_lora_adapter_success():
serving_models = await _async_serving_models_init()
request = LoadLoraAdapterRequest(lora_name="adapter",
request = LoadLoRAAdapterRequest(lora_name="adapter",
lora_path="/path/to/adapter2")
response = await serving_models.load_lora_adapter(request)
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
@ -62,7 +62,7 @@ async def test_load_lora_adapter_success():
@pytest.mark.asyncio
async def test_load_lora_adapter_missing_fields():
serving_models = await _async_serving_models_init()
request = LoadLoraAdapterRequest(lora_name="", lora_path="")
request = LoadLoRAAdapterRequest(lora_name="", lora_path="")
response = await serving_models.load_lora_adapter(request)
assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput"
@ -72,14 +72,14 @@ async def test_load_lora_adapter_missing_fields():
@pytest.mark.asyncio
async def test_load_lora_adapter_duplicate():
serving_models = await _async_serving_models_init()
request = LoadLoraAdapterRequest(lora_name="adapter1",
request = LoadLoRAAdapterRequest(lora_name="adapter1",
lora_path="/path/to/adapter1")
response = await serving_models.load_lora_adapter(request)
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
lora_name='adapter1')
assert len(serving_models.lora_requests) == 1
request = LoadLoraAdapterRequest(lora_name="adapter1",
request = LoadLoRAAdapterRequest(lora_name="adapter1",
lora_path="/path/to/adapter1")
response = await serving_models.load_lora_adapter(request)
assert isinstance(response, ErrorResponse)
@ -91,12 +91,12 @@ async def test_load_lora_adapter_duplicate():
@pytest.mark.asyncio
async def test_unload_lora_adapter_success():
serving_models = await _async_serving_models_init()
request = LoadLoraAdapterRequest(lora_name="adapter1",
request = LoadLoRAAdapterRequest(lora_name="adapter1",
lora_path="/path/to/adapter1")
response = await serving_models.load_lora_adapter(request)
assert len(serving_models.lora_requests) == 1
request = UnloadLoraAdapterRequest(lora_name="adapter1")
request = UnloadLoRAAdapterRequest(lora_name="adapter1")
response = await serving_models.unload_lora_adapter(request)
assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
lora_name='adapter1')
@ -106,7 +106,7 @@ async def test_unload_lora_adapter_success():
@pytest.mark.asyncio
async def test_unload_lora_adapter_missing_fields():
serving_models = await _async_serving_models_init()
request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None)
request = UnloadLoRAAdapterRequest(lora_name="", lora_int_id=None)
response = await serving_models.unload_lora_adapter(request)
assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput"
@ -116,7 +116,7 @@ async def test_unload_lora_adapter_missing_fields():
@pytest.mark.asyncio
async def test_unload_lora_adapter_not_found():
serving_models = await _async_serving_models_init()
request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter")
request = UnloadLoRAAdapterRequest(lora_name="nonexistent_adapter")
response = await serving_models.unload_lora_adapter(request)
assert isinstance(response, ErrorResponse)
assert response.type == "NotFoundError"

View File

@ -14,16 +14,16 @@ from vllm.config import LoRAConfig
from vllm.lora.fully_sharded_layers import (
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LinearScalingRotaryEmbeddingWithLora,
LinearScalingRotaryEmbeddingWithLoRA,
LogitsProcessorWithLoRA, LoRAMapping,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora,
MergedQKVParallelLinearWithLoRA,
QKVParallelLinearWithLoRA,
ReplicatedLinearWithLoRA,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA)
@ -866,9 +866,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
bias=False,
params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = (MergedQKVParallelLinearWithLora(linear)
lora_linear = (MergedQKVParallelLinearWithLoRA(linear)
if not fully_shard else
MergedQKVParallelLinearWithShardedLora(linear))
MergedQKVParallelLinearWithShardedLoRA(linear))
else:
linear = QKVParallelLinear(4096,
64,
@ -876,9 +876,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
bias=False,
params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = QKVParallelLinearWithLora(
lora_linear = QKVParallelLinearWithLoRA(
linear
) if not fully_shard else QKVParallelLinearWithShardedLora(linear)
) if not fully_shard else QKVParallelLinearWithShardedLoRA(linear)
@dataclass
class FakeConfig:
@ -1024,7 +1024,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
base,
is_neox_style,
)
lora_rope = LinearScalingRotaryEmbeddingWithLora(rope)
lora_rope = LinearScalingRotaryEmbeddingWithLoRA(rope)
lora_rope.set_mapping(punica_wrapper)
lora_rope.create_lora_weights(max_loras, lora_config)
linear_rope = get_rope(head_size, rotary_dim, max_position, base,

View File

@ -8,7 +8,7 @@ import pytest
import vllm
from vllm import SamplingParams
from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLora
from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLoRA
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.rotary_embedding import (
LinearScalingRotaryEmbedding)
@ -151,7 +151,7 @@ def test_rotary_emb_replaced(dist_init):
if "rotary_emb" in module_name:
if "base_layer" not in module_name:
rotary_emb_count += 1
assert isinstance(module, LinearScalingRotaryEmbeddingWithLora)
assert isinstance(module, LinearScalingRotaryEmbeddingWithLoRA)
else:
assert isinstance(module, LinearScalingRotaryEmbedding)
# Llama 2 has 32 layers.

View File

@ -1629,7 +1629,7 @@ class LLMEngine:
max_tokens_requests: List[int] = []
finished_reason_requests: List[str] = []
# Lora requests
# LoRA requests
running_lora_adapters = dict(
collectionsCounter([
running_request.lora_request.lora_name

View File

@ -53,7 +53,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse,
LoadLoraAdapterRequest,
LoadLoRAAdapterRequest,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingRequest, PoolingResponse,
@ -63,7 +63,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TokenizeResponse,
TranscriptionRequest,
TranscriptionResponse,
UnloadLoraAdapterRequest)
UnloadLoRAAdapterRequest)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
@ -690,12 +690,12 @@ if envs.VLLM_TORCH_PROFILER_DIR:
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
logger.warning(
"Lora dynamic loading & unloading is enabled in the API server. "
"LoRA dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!")
@router.post("/v1/load_lora_adapter",
dependencies=[Depends(validate_json_request)])
async def load_lora_adapter(request: LoadLoraAdapterRequest,
async def load_lora_adapter(request: LoadLoRAAdapterRequest,
raw_request: Request):
handler = models(raw_request)
response = await handler.load_lora_adapter(request)
@ -707,7 +707,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
@router.post("/v1/unload_lora_adapter",
dependencies=[Depends(validate_json_request)])
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
async def unload_lora_adapter(request: UnloadLoRAAdapterRequest,
raw_request: Request):
handler = models(raw_request)
response = await handler.unload_lora_adapter(request)
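
A hedged client-side sketch of exercising these two endpoints. Host, port, and the adapter path are placeholders, and the server must be started with VLLM_ALLOW_RUNTIME_LORA_UPDATING enabled, as the warning above notes.

```python
import json
import urllib.request

def post_json(url: str, payload: dict) -> str:
    req = urllib.request.Request(url,
                                 data=json.dumps(payload).encode("utf-8"),
                                 headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req) as resp:
        return resp.read().decode("utf-8")

base = "http://localhost:8000"  # placeholder server address
print(post_json(f"{base}/v1/load_lora_adapter",
                {"lora_name": "sql-lora",
                 "lora_path": "/path/to/sql-lora-adapter"}))
print(post_json(f"{base}/v1/unload_lora_adapter", {"lora_name": "sql-lora"}))
```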

View File

@ -1431,12 +1431,12 @@ class DetokenizeResponse(OpenAIBaseModel):
prompt: str
class LoadLoraAdapterRequest(BaseModel):
class LoadLoRAAdapterRequest(BaseModel):
lora_name: str
lora_path: str
class UnloadLoraAdapterRequest(BaseModel):
class UnloadLoRAAdapterRequest(BaseModel):
lora_name: str
lora_int_id: Optional[int] = Field(default=None)

View File

@ -9,10 +9,10 @@ from typing import List, Optional, Union
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse,
LoadLoraAdapterRequest,
LoadLoRAAdapterRequest,
ModelCard, ModelList,
ModelPermission,
UnloadLoraAdapterRequest)
UnloadLoRAAdapterRequest)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
@ -88,7 +88,7 @@ class OpenAIServingModels:
if self.static_lora_modules is None:
return
for lora in self.static_lora_modules:
load_request = LoadLoraAdapterRequest(lora_path=lora.path,
load_request = LoadLoRAAdapterRequest(lora_path=lora.path,
lora_name=lora.name)
load_result = await self.load_lora_adapter(
request=load_request, base_model_name=lora.base_model_name)
@ -140,7 +140,7 @@ class OpenAIServingModels:
async def load_lora_adapter(
self,
request: LoadLoraAdapterRequest,
request: LoadLoRAAdapterRequest,
base_model_name: Optional[str] = None
) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_load_lora_adapter_request(request)
@ -177,7 +177,7 @@ class OpenAIServingModels:
async def unload_lora_adapter(
self,
request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
request: UnloadLoRAAdapterRequest) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_unload_lora_adapter_request(request
)
if error_check_ret is not None:
@ -192,7 +192,7 @@ class OpenAIServingModels:
return f"Success: LoRA adapter '{lora_name}' removed successfully."
async def _check_load_lora_adapter_request(
self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
self, request: LoadLoRAAdapterRequest) -> Optional[ErrorResponse]:
# Check if both 'lora_name' and 'lora_path' are provided
if not request.lora_name or not request.lora_path:
return create_error_response(
@ -214,7 +214,7 @@ class OpenAIServingModels:
async def _check_unload_lora_adapter_request(
self,
request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
request: UnloadLoRAAdapterRequest) -> Optional[ErrorResponse]:
# Check if either 'lora_name' or 'lora_int_id' is provided
if not request.lora_name and not request.lora_int_id:
return create_error_response(

View File

@ -13,8 +13,8 @@ from vllm.distributed.communication_op import (
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora,
MergedQKVParallelLinearWithLoRA,
QKVParallelLinearWithLoRA,
RowParallelLinearWithLoRA)
if TYPE_CHECKING:
@ -167,9 +167,9 @@ class MergedColumnParallelLinearWithShardedLoRA(
)
class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
"""
Differs from QKVParallelLinearWithLora by slicing the
Differs from QKVParallelLinearWithLoRA by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
@ -202,9 +202,9 @@ class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
)
class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
"""
Differs from MergedQKVParallelLinearWithLora by slicing the
Differs from MergedQKVParallelLinearWithLoRA by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
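
A minimal sketch of the sharding idea described above: each tensor-parallel rank keeps only its slice of LoRA A along the rank dimension. The shapes and layout below are assumptions for illustration, not vLLM's internal layout.

```python
import torch

def shard_lora_a(lora_a: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    """lora_a: (lora_rank, hidden_size); return this rank's slice of the rank dim."""
    lora_rank = lora_a.shape[0]
    assert lora_rank % tp_size == 0, "sketch assumes lora_rank divisible by tp_size"
    shard = lora_rank // tp_size
    return lora_a[tp_rank * shard:(tp_rank + 1) * shard]

lora_a = torch.randn(16, 4096)                 # rank 16, hidden size 4096 (made up)
local_a = shard_lora_a(lora_a, tp_rank=1, tp_size=4)
assert local_a.shape == (4, 4096)
```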

View File

@ -363,7 +363,7 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
embeddings_tensor: Optional[torch.Tensor],
lora_bias: Optional[torch.Tensor] = None,
):
# Except for QKVParallelLinearWithLora and
# Except for QKVParallelLinearWithLoRA and
# MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
# store weights in a tuple of size 1. These two layers will
# override this function.
@ -686,7 +686,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
and len(packed_modules_list) == 2)
class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"""
ColumnParallelLinear layer that is specifically designed for
qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
@ -754,7 +754,7 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
packed_modules_list) == 1
class MergedQKVParallelLinearWithLora(MergedColumnParallelLinearWithLoRA):
class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
"""MergedColumnParallelLinear layer that is composed of 3 sublayers (slices)
packed together in qkv proj fashion
(q_proj + k_proj + v_proj -> qkv_proj).
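
The merged layer above packs q_proj, k_proj, and v_proj into one projection; below is a hedged sketch of applying three per-slice LoRA deltas at their offsets in the packed output. Names and shapes are illustrative, not vLLM's kernel.

```python
import torch

def apply_packed_qkv_lora(x: torch.Tensor, base_out: torch.Tensor,
                          loras, slice_sizes) -> torch.Tensor:
    """x: (tokens, hidden); base_out: (tokens, q+k+v); loras: [(A, B), ...]."""
    offset = 0
    for (lora_a, lora_b), size in zip(loras, slice_sizes):
        # A: (hidden, r), B: (r, size) -> low-rank delta for this slice.
        base_out[:, offset:offset + size] += x @ lora_a @ lora_b
        offset += size
    return base_out

hidden, r = 64, 8
slice_sizes = [64, 16, 16]                     # q, k, v widths (made up, GQA-style)
x = torch.randn(4, hidden)
base_out = torch.randn(4, sum(slice_sizes))
loras = [(torch.randn(hidden, r), torch.randn(r, s)) for s in slice_sizes]
out = apply_packed_qkv_lora(x, base_out, loras, slice_sizes)
assert out.shape == (4, 96)
```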
@ -1120,7 +1120,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
return False
class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA):
"""Implements RoPE-scaled embeddings with linear scaling for
multiple LoRA adapters with a specialized kernel.

View File

@ -20,7 +20,7 @@ from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import (BaseLayerWithLoRA,
LinearScalingRotaryEmbeddingWithLora,
LinearScalingRotaryEmbeddingWithLoRA,
LoRAMapping)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
@ -201,7 +201,7 @@ class LoRAModel(AdapterModel):
expected_lora_modules: Name of modules that are expected to be
replaced by lora.
peft_helper: Loaded lora configuration information.
lora_model_id: Lora model id. If not given, automatically set by
lora_model_id: LoRA model id. If not given, automatically set by
a global counter.
device: Device where the lora model is loaded.
dtype: dtype of the lora model weights.
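
For orientation, a hedged end-to-end usage sketch of the loading path this docstring describes. The model name and adapter path are placeholders, and running it requires a GPU and real adapter files.

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)  # placeholder model
outputs = llm.generate(
    ["Write a SQL query that lists all users."],
    SamplingParams(max_tokens=32),
    # LoRARequest(lora_name, lora_int_id, lora_path); the int id must uniquely
    # identify the adapter across requests.
    lora_request=LoRARequest("sql-lora", 1, "/path/to/sql-lora-adapter"),
)
print(outputs[0].outputs[0].text)
```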
@ -480,9 +480,9 @@ class LoRAModelManager(AdapterModelManager):
from_layer(module, self.lora_slots, self.lora_config,
packed_moduled_lst, self.model.config))
# LinearScalingRotaryEmbeddingWithLora is used to handle
# LinearScalingRotaryEmbeddingWithLoRA is used to handle
# long context lora. Register relevant metadata.
if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
if isinstance(new_module, LinearScalingRotaryEmbeddingWithLoRA):
self.long_lora_context = LongContextLoRAContext(
new_module.scaling_factors, new_module.rotary_dim)
self.scaling_factor_to_offset = \
@ -527,7 +527,7 @@ class LoRAModelManager(AdapterModelManager):
bias_enabled = self.lora_config.bias_enabled
if (not self._match_target_modules(module_name)
or not isinstance(module, BaseLayerWithLoRA)
or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
or isinstance(module, LinearScalingRotaryEmbeddingWithLoRA)
or self._filter_unsupported_mm_module(module_name)):
continue
parts = module_name.split(".")

View File

@ -42,7 +42,7 @@ class PEFTHelper:
def _validate_features(self) -> List[str]:
"""
Check if there are any unsupported Lora features.
Check if there are any unsupported LoRA features.
"""
error_msg = []
if self.modules_to_save:
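
The hunk above cuts off mid-check; here is a minimal hedged sketch of the collect-then-return validation pattern it uses. The fields and messages are illustrative, not the full PEFTHelper.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class PEFTConfigSketch:
    modules_to_save: Optional[List[str]] = None
    use_dora: bool = False                     # hypothetical unsupported feature

    def validate_features(self) -> List[str]:
        """Collect messages for unsupported LoRA features instead of failing fast."""
        error_msg: List[str] = []
        if self.modules_to_save:
            error_msg.append("modules_to_save is not supported in this sketch")
        if self.use_dora:
            error_msg.append("DoRA is not supported in this sketch")
        return error_msg

assert PEFTConfigSketch(modules_to_save=["lm_head"]).validate_features()
```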

View File

@ -314,7 +314,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
def long_lora_indices(self) -> torch.Tensor:
"""
This property provides access to the indices used for long context
lora, specifically for LinearScalingRotaryEmbeddingWithLora.
lora, specifically for LinearScalingRotaryEmbeddingWithLoRA.
"""
long_lora_len = self.indices_len[4]
return self._long_lora_indices[:long_lora_len]
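
A small self-contained sketch of the buffer-slicing pattern behind this property. The buffer size and surrounding class are assumptions; only the use of slot 4 of indices_len comes from the snippet above.

```python
import torch

class LongLoRAIndicesSketch:
    def __init__(self, max_entries: int = 256):
        # Preallocated buffer; only a prefix is valid at any given step.
        self._long_lora_indices = torch.zeros(max_entries, dtype=torch.long)
        self.indices_len = [0, 0, 0, 0, 0]     # slot 4 tracks long-LoRA entries

    @property
    def long_lora_indices(self) -> torch.Tensor:
        long_lora_len = self.indices_len[4]
        return self._long_lora_indices[:long_lora_len]

w = LongLoRAIndicesSketch()
w.indices_len[4] = 3
assert w.long_lora_indices.shape == (3,)
```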

View File

@ -15,17 +15,17 @@ from vllm.logger import init_logger
from vllm.lora.fully_sharded_layers import (
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA)
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LinearScalingRotaryEmbeddingWithLora,
LinearScalingRotaryEmbeddingWithLoRA,
LogitsProcessorWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora,
MergedQKVParallelLinearWithLoRA,
QKVParallelLinearWithLoRA,
ReplicatedLinearWithLoRA,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA)
@ -41,17 +41,17 @@ _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
VocabParallelEmbeddingWithLoRA,
ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
QKVParallelLinearWithLora,
MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLoRA,
MergedQKVParallelLinearWithLoRA,
RowParallelLinearWithLoRA,
ReplicatedLinearWithLoRA,
LogitsProcessorWithLoRA,
ColumnParallelLinearWithShardedLoRA,
QKVParallelLinearWithShardedLora,
QKVParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora,
MergedQKVParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA,
LinearScalingRotaryEmbeddingWithLora,
LinearScalingRotaryEmbeddingWithLoRA,
}

View File

@ -6,10 +6,10 @@ from typing import List, Optional, Set, Tuple
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import SpeculativeProposer
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
from vllm.worker.worker_base import LoRANotSupportedWorkerBase
class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
class ProposerWorkerBase(LoRANotSupportedWorkerBase, SpeculativeProposer):
"""Interface for proposer workers"""
@abstractmethod

View File

@ -47,7 +47,7 @@ from vllm.spec_decode.util import (Timer, create_logprobs_output,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
from vllm.utils import resolve_obj_by_qualname
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
logger = init_logger(__name__)
@ -118,7 +118,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
class SpecDecodeWorker(LoRANotSupportedWorkerBase):
"""Worker which implements speculative decoding.
Speculative decoding reduces decoding per-token latency by using a proposal

View File

@ -21,7 +21,7 @@ ARCTIC_PRETRAINED_CONFIG_ARCHIVE_MAP = {
@dataclass
class ArcticLoraConfig:
class ArcticLoRAConfig:
lora_r: int = 64
lora_alpha: float = 16
shard_base_weights: bool = False

View File

@ -13,11 +13,11 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerBase,
LoRANotSupportedWorkerBase, WorkerBase,
WorkerInput)
class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
"""A worker class that executes the model on a group of neuron cores.
"""

View File

@ -24,7 +24,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.utils import bind_kv_cache
from vllm.worker.openvino_model_runner import OpenVINOModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
logger = init_logger(__name__)
@ -203,7 +203,7 @@ class OpenVINOCacheEngine:
return dtype_size * total
class OpenVINOWorker(LoraNotSupportedWorkerBase):
class OpenVINOWorker(LoRANotSupportedWorkerBase):
"""A worker class that executes the model on OpenVINO backend.
Each worker is associated with a single OpenVINO device. The worker is

View File

@ -17,13 +17,13 @@ from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size
from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerBase,
LoRANotSupportedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__)
class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def __init__(
self,

View File

@ -189,7 +189,7 @@ class DelegateWorkerBase(WorkerBase):
return getattr(self.worker, attr)
class LoraNotSupportedWorkerBase(WorkerBase):
class LoRANotSupportedWorkerBase(WorkerBase):
"""Partial implementation of WorkerBase that raises exceptions when LoRA
methods are invoked.
"""

View File

@ -18,13 +18,13 @@ from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from vllm.worker.xpu_model_runner import XPUModelRunner
logger = init_logger(__name__)
class XPUWorker(LoraNotSupportedWorkerBase, Worker):
class XPUWorker(LoRANotSupportedWorkerBase, Worker):
"""A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single XPU device. The worker is