mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-12 02:17:04 +08:00
[Hardware][CPU] Multi-LoRA implementation for the CPU backend (#11100)
Signed-off-by: Akshat Tripathi <akshat@krai.ai> Signed-off-by: Oleg Mosalov <oleg@krai.ai> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: Oleg Mosalov <oleg@krai.ai> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
f967e51f38
commit
8bddb73512
@ -75,6 +75,12 @@ function cpu_tests() {
|
||||
--num-prompts 20 \
|
||||
--endpoint /v1/completions \
|
||||
--tokenizer facebook/opt-125m"
|
||||
|
||||
# Run multi-lora tests
|
||||
docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pytest -s -v \
|
||||
tests/lora/test_qwen2vl.py"
|
||||
}
|
||||
|
||||
# All of CPU tests are expected to be finished less than 25 mins.
|
||||
|
||||
@ -359,7 +359,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar
|
||||
- ✅
|
||||
- ✅
|
||||
- ✅
|
||||
- [✗](gh-pr:4830)
|
||||
- ✅
|
||||
- ✅
|
||||
* - <abbr title="Prompt Adapter">prmpt adptr</abbr>
|
||||
- ✅
|
||||
|
||||
@ -21,6 +21,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class ContextIDInfo(TypedDict):
|
||||
@ -65,13 +66,16 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
|
||||
@pytest.fixture
|
||||
def dist_init():
|
||||
temp_file = tempfile.mkstemp()[1]
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
rank=0,
|
||||
distributed_init_method=f"file://{temp_file}",
|
||||
local_rank=0,
|
||||
backend="nccl",
|
||||
)
|
||||
|
||||
backend = "nccl"
|
||||
if current_platform.is_cpu():
|
||||
backend = "gloo"
|
||||
|
||||
init_distributed_environment(world_size=1,
|
||||
rank=0,
|
||||
distributed_init_method=f"file://{temp_file}",
|
||||
local_rank=0,
|
||||
backend=backend)
|
||||
initialize_model_parallel(1, 1)
|
||||
yield
|
||||
cleanup_dist_env_and_memory(shutdown_ray=True)
|
||||
@ -81,13 +85,15 @@ def dist_init():
|
||||
def dist_init_torch_only():
|
||||
if torch.distributed.is_initialized():
|
||||
return
|
||||
backend = "nccl"
|
||||
if current_platform.is_cpu():
|
||||
backend = "gloo"
|
||||
|
||||
temp_file = tempfile.mkstemp()[1]
|
||||
torch.distributed.init_process_group(
|
||||
backend="nccl",
|
||||
world_size=1,
|
||||
rank=0,
|
||||
init_method=f"file://{temp_file}",
|
||||
)
|
||||
torch.distributed.init_process_group(world_size=1,
|
||||
rank=0,
|
||||
init_method=f"file://{temp_file}",
|
||||
backend=backend)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -48,10 +48,14 @@ TOLERANCES = {
|
||||
torch.float32: (5e-3, 5e-3),
|
||||
torch.bfloat16: (3e-2, 2e-2),
|
||||
}
|
||||
# TODO: Modify this based on platform
|
||||
DEVICES = [
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not (current_platform.is_cuda_alike() or current_platform.is_cpu()),
|
||||
reason="Backend not supported")
|
||||
|
||||
DEVICES = ([
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
] if current_platform.is_cuda_alike() else ["cpu"])
|
||||
|
||||
#For GPU, we will launch different triton kernels between the prefill and decode
|
||||
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
|
||||
@ -198,6 +202,10 @@ def check_punica_wrapper(punica_wrapper) -> bool:
|
||||
from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
|
||||
|
||||
return type(punica_wrapper) is PunicaWrapperGPU
|
||||
elif current_platform.is_cpu():
|
||||
from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU
|
||||
|
||||
return type(punica_wrapper) is PunicaWrapperCPU
|
||||
else:
|
||||
return False
|
||||
|
||||
@ -211,7 +219,8 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
|
||||
# For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
|
||||
# device, see: https://github.com/triton-lang/triton/issues/2925
|
||||
# Same below.
|
||||
torch.cuda.set_device(device)
|
||||
if current_platform.is_cuda_alike():
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
@ -313,7 +322,9 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
|
||||
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
||||
vocab_size, stage) -> None:
|
||||
|
||||
torch.cuda.set_device(device)
|
||||
if current_platform.is_cuda_alike():
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device)
|
||||
@ -450,7 +461,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
||||
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
|
||||
stage) -> None:
|
||||
|
||||
torch.cuda.set_device(device)
|
||||
if current_platform.is_cuda_alike():
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device)
|
||||
@ -582,7 +595,9 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
|
||||
def test_linear_replicated(dist_init, num_loras, device, stage,
|
||||
bias_enabled) -> None:
|
||||
|
||||
torch.cuda.set_device(device)
|
||||
if current_platform.is_cuda_alike():
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
torch.set_default_device(device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
@ -695,7 +710,9 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
|
||||
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
|
||||
device, stage, bias_enabled) -> None:
|
||||
|
||||
torch.cuda.set_device(device)
|
||||
if current_platform.is_cuda_alike():
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
torch.set_default_device(device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
@ -818,7 +835,9 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
|
||||
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
|
||||
device, stage, bias_enabled) -> None:
|
||||
|
||||
torch.cuda.set_device(device)
|
||||
if current_platform.is_cuda_alike():
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
torch.set_default_device(device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
@ -971,6 +990,8 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
|
||||
@pytest.mark.parametrize("rotary_dim", [None, 32])
|
||||
@pytest.mark.parametrize("head_size", [32, 108])
|
||||
@pytest.mark.parametrize("seq_len", [11, 1024])
|
||||
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
|
||||
reason="Only CUDA backends are supported")
|
||||
def test_rotary_embedding_long_context(dist_init, num_loras, device,
|
||||
scaling_factors, max_position,
|
||||
is_neox_style, rotary_dim, head_size,
|
||||
|
||||
@ -20,6 +20,7 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
|
||||
WorkerLoRAManager)
|
||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
EMBEDDING_MODULES = {
|
||||
"embed_tokens": "input_embeddings",
|
||||
@ -28,9 +29,9 @@ EMBEDDING_MODULES = {
|
||||
|
||||
EMBEDDING_PADDING_MODULES = ["lm_head"]
|
||||
|
||||
CUDA_DEVICES = [
|
||||
DEVICES = ([
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
] if current_platform.is_cuda_alike() else ["cpu"])
|
||||
|
||||
|
||||
def test_peft_helper(sql_lora_files):
|
||||
@ -83,7 +84,7 @@ def test_peft_helper(sql_lora_files):
|
||||
PEFTHelper.from_dict(config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_from_lora_tensors(sql_lora_files, device):
|
||||
tensors = load_file(
|
||||
os.path.join(sql_lora_files, "adapter_model.safetensors"))
|
||||
@ -171,7 +172,7 @@ def test_replace_submodules(dist_init, dummy_model):
|
||||
manager = LoRAModelManager(
|
||||
model, 1, 1, 1,
|
||||
LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
|
||||
torch.device("cuda"))
|
||||
torch.device(DEVICES[0]))
|
||||
model = manager.model
|
||||
|
||||
assert isinstance(model.get_submodule("dense1"),
|
||||
@ -183,7 +184,7 @@ def test_replace_submodules(dist_init, dummy_model):
|
||||
RowParallelLinearWithLoRA)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_lora_model_manager(dist_init, dummy_model, device):
|
||||
model = dummy_model
|
||||
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
|
||||
@ -244,7 +245,7 @@ def test_lora_model_manager(dist_init, dummy_model, device):
|
||||
assert manager.punica_wrapper.device == device
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
|
||||
model = dummy_model
|
||||
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
|
||||
@ -336,7 +337,7 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
|
||||
assert manager.device == device
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_lru_lora_model_manager(dist_init, dummy_model, device):
|
||||
# This tests just the LRU cache functionality, everything else is
|
||||
# tested in test_lora_model_manager
|
||||
@ -466,7 +467,7 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
|
||||
assert manager.device == device
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
sql_lora_files, device):
|
||||
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
|
||||
@ -545,7 +546,7 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
device)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
sql_lora_files, device):
|
||||
# Should remove every LoRA not specified in the request.
|
||||
@ -621,7 +622,7 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
device)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_packed_loras(dist_init, dummy_model_gate_up, device):
|
||||
model = dummy_model_gate_up
|
||||
model.supported_lora_modules = ["gate_up_proj"]
|
||||
|
||||
@ -5,6 +5,7 @@ import torch
|
||||
|
||||
import vllm
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
|
||||
@ -31,7 +32,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
|
||||
@pytest.mark.parametrize("tp_size", [4])
|
||||
def test_mixtral_lora(mixtral_lora_files, tp_size):
|
||||
"""Original test, the LoRA model has the common target modules, not all"""
|
||||
if torch.cuda.device_count() < tp_size:
|
||||
if torch.cuda.device_count(
|
||||
) < tp_size and tp_size > 1 and current_platform.is_cuda_alike():
|
||||
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
|
||||
|
||||
prompts = [
|
||||
|
||||
@ -9,17 +9,16 @@ from threading import Lock
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.lora.ops.bgmv_expand import bgmv_expand
|
||||
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
|
||||
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
|
||||
from vllm.lora.ops.sgmv_expand import sgmv_expand
|
||||
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
|
||||
from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
|
||||
import vllm.lora.ops.triton_ops # noqa: F401
|
||||
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
|
||||
bgmv_shrink, sgmv_expand,
|
||||
sgmv_expand_slice, sgmv_shrink)
|
||||
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .utils import (assert_close, generate_data,
|
||||
generate_data_for_expand_nslices,
|
||||
generate_data_for_nslices, ref_torch_groupgemm)
|
||||
generate_data_for_nslices)
|
||||
|
||||
HIDDEN_SIZES = [
|
||||
128,
|
||||
@ -113,7 +112,7 @@ DTYPES = [torch.float16, torch.bfloat16]
|
||||
MAX_RANKS = [32]
|
||||
SCALES = [0.5]
|
||||
SEED = [0]
|
||||
CUDA_DEVICES = [f"cuda:{0}"]
|
||||
DEVICES = [f"cuda:{0}"]
|
||||
|
||||
_dict_lock = Lock()
|
||||
|
||||
@ -127,7 +126,7 @@ _dict_lock = Lock()
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_punica_sgmv(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
@ -174,7 +173,7 @@ def test_punica_sgmv(
|
||||
# Preventing cache error pointer.
|
||||
with _dict_lock:
|
||||
_LORA_A_PTR_DICT.clear()
|
||||
sgmv_shrink(
|
||||
torch.ops.vllm.sgmv_shrink(
|
||||
inputs_tensor,
|
||||
lora_weights_lst,
|
||||
our_out_tensor,
|
||||
@ -187,20 +186,23 @@ def test_punica_sgmv(
|
||||
scaling,
|
||||
)
|
||||
for index in range(nslices):
|
||||
ref_torch_groupgemm(
|
||||
ref_out_tensor[index],
|
||||
sgmv_shrink(
|
||||
inputs_tensor,
|
||||
lora_weights_lst[index],
|
||||
lora_indices_tensor,
|
||||
ref_out_tensor[index],
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
scaling,
|
||||
op_type,
|
||||
)
|
||||
|
||||
else:
|
||||
with _dict_lock:
|
||||
_LORA_B_PTR_DICT.clear()
|
||||
sgmv_expand(
|
||||
torch.ops.vllm.sgmv_expand(
|
||||
inputs_tensor,
|
||||
lora_weights_lst,
|
||||
our_out_tensor,
|
||||
@ -213,21 +215,39 @@ def test_punica_sgmv(
|
||||
offset_start=0,
|
||||
add_inputs=True,
|
||||
)
|
||||
|
||||
slice_offset = 0
|
||||
for index in range(nslices):
|
||||
lora_weights = lora_weights_lst[index]
|
||||
ref_torch_groupgemm(
|
||||
ref_out_tensor[:, slice_offset:slice_offset + hidden_size],
|
||||
inputs_tensor[index],
|
||||
lora_weights,
|
||||
lora_indices_tensor,
|
||||
if nslices == 1:
|
||||
# Verify the torch's sgmv_expand op
|
||||
sgmv_expand(
|
||||
inputs_tensor[0],
|
||||
lora_weights_lst[0],
|
||||
ref_out_tensor,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
1.0,
|
||||
op_type,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
add_inputs=True,
|
||||
)
|
||||
slice_offset += hidden_size
|
||||
else:
|
||||
slice_offset = 0
|
||||
for index in range(nslices):
|
||||
lora_weights = lora_weights_lst[index]
|
||||
sgmv_expand_slice(
|
||||
inputs_tensor[index],
|
||||
lora_weights,
|
||||
ref_out_tensor,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
slice_offset,
|
||||
hidden_size,
|
||||
add_inputs=True,
|
||||
)
|
||||
slice_offset += hidden_size
|
||||
|
||||
assert_close(our_out_tensor, ref_out_tensor)
|
||||
|
||||
@ -240,7 +260,7 @@ def test_punica_sgmv(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_punica_bgmv(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
@ -276,31 +296,38 @@ def test_punica_bgmv(
|
||||
device,
|
||||
)
|
||||
if op_type == "shrink":
|
||||
bgmv_shrink(
|
||||
torch.ops.vllm.bgmv_shrink(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_out_tensor,
|
||||
indices,
|
||||
scaling,
|
||||
)
|
||||
|
||||
bgmv_shrink(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
ref_out_tensor,
|
||||
indices,
|
||||
scaling,
|
||||
)
|
||||
|
||||
else:
|
||||
bgmv_expand(
|
||||
torch.ops.vllm.bgmv_expand(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_out_tensor,
|
||||
indices,
|
||||
add_inputs=True,
|
||||
)
|
||||
ref_torch_groupgemm(
|
||||
ref_out_tensor,
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
batches,
|
||||
scaling if op_type == "shrink" else 1.0,
|
||||
op_type,
|
||||
)
|
||||
bgmv_expand(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
ref_out_tensor,
|
||||
indices,
|
||||
add_inputs=True,
|
||||
)
|
||||
|
||||
if op_type == "shrink":
|
||||
ref_out_tensor = ref_out_tensor.to(torch.float32)
|
||||
assert_close(our_out_tensor, ref_out_tensor)
|
||||
@ -313,7 +340,7 @@ def test_punica_bgmv(
|
||||
@pytest.mark.parametrize("nslices", [2, 3])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_punica_bgmv_expand_nslices(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
@ -350,7 +377,7 @@ def test_punica_bgmv_expand_nslices(
|
||||
slice_offset = 0
|
||||
for index in range(nslices):
|
||||
lora_weights = lora_weights_lst[index]
|
||||
bgmv_expand_slice(
|
||||
torch.ops.vllm.bgmv_expand_slice(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_outputs,
|
||||
@ -359,15 +386,14 @@ def test_punica_bgmv_expand_nslices(
|
||||
slice_size=hidden_size,
|
||||
add_inputs=True,
|
||||
)
|
||||
ref_torch_groupgemm(
|
||||
ref_outputs[:, slice_offset:slice_offset + hidden_size],
|
||||
bgmv_expand_slice(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
batches,
|
||||
1.0,
|
||||
op_type="expand",
|
||||
ref_outputs,
|
||||
indices,
|
||||
slice_offset,
|
||||
slice_size=hidden_size,
|
||||
add_inputs=True,
|
||||
)
|
||||
|
||||
slice_offset += hidden_size
|
||||
@ -9,19 +9,18 @@ import pytest
|
||||
import torch
|
||||
|
||||
# Enable custom op register
|
||||
import vllm.lora.ops.bgmv_expand
|
||||
import vllm.lora.ops.bgmv_expand_slice
|
||||
import vllm.lora.ops.bgmv_shrink
|
||||
import vllm.lora.ops.sgmv_expand
|
||||
import vllm.lora.ops.sgmv_shrink # noqa: F401
|
||||
from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
|
||||
import vllm.lora.ops.triton_ops # noqa: F401
|
||||
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
|
||||
bgmv_shrink, sgmv_expand,
|
||||
sgmv_expand_slice, sgmv_shrink)
|
||||
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .utils import (assert_close, generate_data,
|
||||
generate_data_for_expand_nslices,
|
||||
generate_data_for_nslices, ref_torch_groupgemm)
|
||||
generate_data_for_nslices)
|
||||
|
||||
HIDDEN_SIZES = [4097]
|
||||
HIDDEN_SIZES = [2049]
|
||||
|
||||
BATCHES = [1, 4, 16, 32]
|
||||
NUM_LORA = [1, 8, 32, 128]
|
||||
@ -29,15 +28,7 @@ DTYPES = [torch.float16, torch.bfloat16]
|
||||
MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256]
|
||||
SCALES = [0.5]
|
||||
SEED = [0]
|
||||
CUDA_DEVICES = [f"cuda:{0}"]
|
||||
|
||||
# Unlike test_punica_sizes.py, we directly utilize custom op for
|
||||
# testing, which verifies the correct registration of these ops.
|
||||
bgmv_expand = torch.ops.vllm.bgmv_expand
|
||||
bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice
|
||||
bgmv_shrink = torch.ops.vllm.bgmv_shrink
|
||||
sgmv_expand = torch.ops.vllm.sgmv_expand
|
||||
sgmv_shrink = torch.ops.vllm.sgmv_shrink
|
||||
DEVICES = [f"cuda:{0}"]
|
||||
|
||||
_dict_lock = Lock()
|
||||
|
||||
@ -51,7 +42,7 @@ _dict_lock = Lock()
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_punica_sgmv(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
@ -98,7 +89,7 @@ def test_punica_sgmv(
|
||||
# Preventing cache error pointer.
|
||||
with _dict_lock:
|
||||
_LORA_A_PTR_DICT.clear()
|
||||
sgmv_shrink(
|
||||
torch.ops.vllm.sgmv_shrink(
|
||||
inputs_tensor,
|
||||
lora_weights_lst,
|
||||
our_out_tensor,
|
||||
@ -111,20 +102,23 @@ def test_punica_sgmv(
|
||||
scaling,
|
||||
)
|
||||
for index in range(nslices):
|
||||
ref_torch_groupgemm(
|
||||
ref_out_tensor[index],
|
||||
sgmv_shrink(
|
||||
inputs_tensor,
|
||||
lora_weights_lst[index],
|
||||
lora_indices_tensor,
|
||||
ref_out_tensor[index],
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
scaling,
|
||||
op_type,
|
||||
)
|
||||
|
||||
else:
|
||||
with _dict_lock:
|
||||
_LORA_B_PTR_DICT.clear()
|
||||
sgmv_expand(
|
||||
torch.ops.vllm.sgmv_expand(
|
||||
inputs_tensor,
|
||||
lora_weights_lst,
|
||||
our_out_tensor,
|
||||
@ -137,21 +131,39 @@ def test_punica_sgmv(
|
||||
offset_start=0,
|
||||
add_inputs=True,
|
||||
)
|
||||
|
||||
slice_offset = 0
|
||||
for index in range(nslices):
|
||||
lora_weights = lora_weights_lst[index]
|
||||
ref_torch_groupgemm(
|
||||
ref_out_tensor[:, slice_offset:slice_offset + hidden_size],
|
||||
inputs_tensor[index],
|
||||
lora_weights,
|
||||
lora_indices_tensor,
|
||||
if nslices == 1:
|
||||
# Verify the torch's sgmv_expand op
|
||||
sgmv_expand(
|
||||
inputs_tensor[0],
|
||||
lora_weights_lst[0],
|
||||
ref_out_tensor,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
1.0,
|
||||
op_type,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
add_inputs=True,
|
||||
)
|
||||
slice_offset += hidden_size
|
||||
else:
|
||||
for index in range(nslices):
|
||||
lora_weights = lora_weights_lst[index]
|
||||
sgmv_expand_slice(
|
||||
inputs_tensor[index],
|
||||
lora_weights,
|
||||
ref_out_tensor,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
slice_offset,
|
||||
hidden_size,
|
||||
add_inputs=True,
|
||||
)
|
||||
slice_offset += hidden_size
|
||||
|
||||
assert_close(our_out_tensor, ref_out_tensor)
|
||||
|
||||
@ -164,7 +176,7 @@ def test_punica_sgmv(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_punica_bgmv(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
@ -176,7 +188,6 @@ def test_punica_bgmv(
|
||||
seed: int,
|
||||
device: str,
|
||||
):
|
||||
|
||||
torch.set_default_device(device)
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
@ -201,32 +212,38 @@ def test_punica_bgmv(
|
||||
device,
|
||||
)
|
||||
if op_type == "shrink":
|
||||
bgmv_shrink(
|
||||
torch.ops.vllm.bgmv_shrink(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_out_tensor,
|
||||
indices,
|
||||
scaling,
|
||||
)
|
||||
else:
|
||||
|
||||
bgmv_expand(
|
||||
bgmv_shrink(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
ref_out_tensor,
|
||||
indices,
|
||||
scaling,
|
||||
)
|
||||
|
||||
else:
|
||||
torch.ops.vllm.bgmv_expand(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_out_tensor,
|
||||
indices,
|
||||
add_inputs=True,
|
||||
)
|
||||
ref_torch_groupgemm(
|
||||
ref_out_tensor,
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
batches,
|
||||
scaling if op_type == "shrink" else 1.0,
|
||||
op_type,
|
||||
)
|
||||
bgmv_expand(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
ref_out_tensor,
|
||||
indices,
|
||||
add_inputs=True,
|
||||
)
|
||||
|
||||
if op_type == "shrink":
|
||||
ref_out_tensor = ref_out_tensor.to(torch.float32)
|
||||
assert_close(our_out_tensor, ref_out_tensor)
|
||||
@ -239,7 +256,7 @@ def test_punica_bgmv(
|
||||
@pytest.mark.parametrize("nslices", [2, 3])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_punica_bgmv_expand_nslices(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
@ -276,7 +293,7 @@ def test_punica_bgmv_expand_nslices(
|
||||
slice_offset = 0
|
||||
for index in range(nslices):
|
||||
lora_weights = lora_weights_lst[index]
|
||||
bgmv_expand_slice(
|
||||
torch.ops.vllm.bgmv_expand_slice(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_outputs,
|
||||
@ -285,15 +302,14 @@ def test_punica_bgmv_expand_nslices(
|
||||
slice_size=hidden_size,
|
||||
add_inputs=True,
|
||||
)
|
||||
ref_torch_groupgemm(
|
||||
ref_outputs[:, slice_offset:slice_offset + hidden_size],
|
||||
bgmv_expand_slice(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
batches,
|
||||
1.0,
|
||||
op_type="expand",
|
||||
ref_outputs,
|
||||
indices,
|
||||
slice_offset,
|
||||
slice_size=hidden_size,
|
||||
add_inputs=True,
|
||||
)
|
||||
|
||||
slice_offset += hidden_size
|
||||
@ -72,7 +72,8 @@ def do_sample(llm: vllm.LLM,
|
||||
@pytest.mark.parametrize("tp_size", [1])
|
||||
def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
|
||||
tp_size):
|
||||
if num_gpus_available < tp_size:
|
||||
if num_gpus_available < tp_size and \
|
||||
tp_size > 1 and current_platform.is_cuda_alike():
|
||||
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
|
||||
|
||||
llm = vllm.LLM(
|
||||
|
||||
@ -104,33 +104,6 @@ def assert_close(a, b):
|
||||
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
def ref_torch_groupgemm(
|
||||
out_tensor,
|
||||
inputs,
|
||||
lora_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
batches,
|
||||
scaling,
|
||||
op_type,
|
||||
) -> torch.Tensor:
|
||||
out_list = []
|
||||
current_offset = 0
|
||||
for lora_index, b_length in zip(range(batches), seq_len_tensor):
|
||||
input_weight = inputs[current_offset:b_length + current_offset, :]
|
||||
current_offset += b_length
|
||||
lora_weight = lora_weights[lora_indices_tensor[lora_index]]
|
||||
result = torch.nn.functional.linear(input_weight, lora_weight)
|
||||
result *= scaling
|
||||
out_list.append(result)
|
||||
cat_result = torch.cat(out_list, dim=0)
|
||||
if op_type == "expand":
|
||||
out_tensor += cat_result
|
||||
else:
|
||||
out_tensor.copy_(cat_result)
|
||||
return
|
||||
|
||||
|
||||
def generate_data(
|
||||
batches,
|
||||
hidden_size,
|
||||
|
||||
@ -22,9 +22,6 @@ class CPUExecutor(ExecutorBase):
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
assert self.device_config.device_type == "cpu"
|
||||
# Reminder: Please update docs/source/features/compatibility_matrix.md
|
||||
# If the feature combo become valid
|
||||
assert self.lora_config is None, "cpu backend doesn't support LoRA"
|
||||
|
||||
#
|
||||
# Environment variables for CPU executor
|
||||
|
||||
13
vllm/lora/ops/torch_ops/__init__.py
Normal file
13
vllm/lora/ops/torch_ops/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
from vllm.lora.ops.torch_ops.lora_ops import bgmv_expand # noqa: F401
|
||||
from vllm.lora.ops.torch_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink,
|
||||
sgmv_expand, sgmv_expand_slice,
|
||||
sgmv_shrink)
|
||||
|
||||
__all__ = [
|
||||
"bgmv_expand",
|
||||
"bgmv_expand_slice",
|
||||
"bgmv_shrink",
|
||||
"sgmv_expand",
|
||||
"sgmv_expand_slice",
|
||||
"sgmv_shrink",
|
||||
]
|
||||
113
vllm/lora/ops/torch_ops/lora_ops.py
Normal file
113
vllm/lora/ops/torch_ops/lora_ops.py
Normal file
@ -0,0 +1,113 @@
|
||||
import torch
|
||||
|
||||
|
||||
def sgmv_expand(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
add_inputs: bool = False):
|
||||
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
|
||||
seq_len_tensor)
|
||||
|
||||
bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices,
|
||||
add_inputs)
|
||||
|
||||
|
||||
def bgmv_expand(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
add_inputs: bool = True):
|
||||
selected_loras = lora_b_weights[lora_indices_tensor].to(
|
||||
dtype=output_tensor.dtype)
|
||||
if len(selected_loras.shape) == 4:
|
||||
selected_loras = selected_loras.squeeze(dim=1)
|
||||
inputs = inputs.to(dtype=output_tensor.dtype)
|
||||
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
|
||||
|
||||
limit = output_tensor.shape[0]
|
||||
if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
|
||||
limit = 1
|
||||
|
||||
if add_inputs:
|
||||
output_tensor[:, :outputs.shape[1]] += outputs[:limit, :]
|
||||
else:
|
||||
output_tensor[:, :outputs.shape[1]] = outputs[:limit, :]
|
||||
|
||||
|
||||
def sgmv_shrink(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
scaling: float,
|
||||
):
|
||||
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
|
||||
seq_len_tensor)
|
||||
|
||||
bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices,
|
||||
scaling)
|
||||
|
||||
|
||||
def bgmv_shrink(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
scaling: float = 1.0):
|
||||
selected_loras = lora_b_weights[lora_indices_tensor].to(
|
||||
dtype=output_tensor.dtype)
|
||||
if len(selected_loras.shape) == 4:
|
||||
selected_loras = selected_loras.squeeze(dim=1)
|
||||
inputs = inputs.to(dtype=output_tensor.dtype)
|
||||
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
|
||||
|
||||
output_tensor[:, :outputs.shape[1]] = scaling * outputs[:]
|
||||
|
||||
|
||||
def sgmv_expand_slice(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = False):
|
||||
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
|
||||
seq_len_tensor)
|
||||
|
||||
bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices,
|
||||
slice_offset, slice_size, add_inputs)
|
||||
|
||||
|
||||
def bgmv_expand_slice(inputs: torch.Tensor,
                      lora_b_weights: torch.Tensor,
                      output_tensor: torch.Tensor,
                      lora_indices_tensor: torch.Tensor,
                      slice_offset: int,
                      slice_size: int,
                      add_inputs: bool = True):
    """Batched per-token expand GEMV into an output column slice.

    Computes ``inputs @ W[idx].T`` per token and writes (or accumulates,
    when ``add_inputs``) the result into
    ``output_tensor[:, slice_offset:slice_offset + slice_size]``.
    """
    # Gather each token's LoRA weight and match the output dtype.
    active_weights = lora_b_weights[lora_indices_tensor].to(
        dtype=output_tensor.dtype)
    cast_inputs = inputs.to(dtype=output_tensor.dtype)
    # Drop the singleton layer dimension if the stacked weights carry one.
    if active_weights.ndim == 4:
        active_weights = active_weights.squeeze(dim=1)
    # Per-token matvec: out[b, o] = sum_i in[b, i] * W[b, o, i].
    expanded = torch.einsum("bi, boi -> bo", cast_inputs, active_weights)

    target = output_tensor[:, slice_offset:slice_offset + slice_size]
    if add_inputs:
        target += expanded
    else:
        target[:] = expanded
||||
13
vllm/lora/ops/triton_ops/__init__.py
Normal file
13
vllm/lora/ops/triton_ops/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand
|
||||
from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice
|
||||
from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink
|
||||
from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
|
||||
from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink # noqa: F401
|
||||
|
||||
__all__ = [
|
||||
"bgmv_expand",
|
||||
"bgmv_expand_slice",
|
||||
"bgmv_shrink",
|
||||
"sgmv_expand",
|
||||
"sgmv_shrink",
|
||||
]
|
||||
346
vllm/lora/punica_wrapper/punica_cpu.py
Normal file
346
vllm/lora/punica_wrapper/punica_cpu.py
Normal file
@ -0,0 +1,346 @@
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
|
||||
bgmv_shrink, sgmv_expand,
|
||||
sgmv_expand_slice, sgmv_shrink)
|
||||
|
||||
from .punica_base import PunicaWrapperBase
|
||||
|
||||
|
||||
# The platforms that are compatible with the PyTorch-native implementation can
# inherit this class
||||
class PunicaWrapperCPU(PunicaWrapperBase):
    """
    PunicaWrapperCPU is designed to manage and provide metadata for the punica
    kernel. The main function is to maintain the state information for
    Multi-LoRA, and to provide the interface for the pytorch punica ops.

    All state (prefill metadata, token/sampler LoRA index tensors, the
    ``no_lora`` / ``is_prefill`` flags) is maintained by ``PunicaWrapperBase``;
    this subclass only dispatches to the PyTorch-native kernels.
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        # All buffers are allocated by the base class; nothing CPU-specific
        # is needed here.
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
                                   device)

    def _shrink_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        # No LoRA request, so return directly.
        if self.no_lora:
            return
        # self.prefill_metadata supplies the sequence layout expected by the
        # sgmv op (presumably: seq start locations, seq lengths, LoRA
        # indices, batch count, max seq length, token count — see
        # PunicaWrapperBase; TODO confirm).
        sgmv_shrink(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            scale,
        )

    def _shrink_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        # Decode: one token per sequence, so the batched (bgmv) op with the
        # per-token LoRA indices is sufficient.
        bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)

    def _expand_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_inputs: bool,
    ):
        # No LoRA request, so return directly.
        if self.no_lora:
            return
        sgmv_expand(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            add_inputs,
        )

    def _expand_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_inputs: bool,
    ):
        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)

    def _expand_slice_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool,
    ):
        # No LoRA request, so return directly.
        if self.no_lora:
            return
        sgmv_expand_slice(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            y_offset,
            y_slice_size,
            add_inputs,
        )

    def _expand_slice_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool,
    ):
        bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
                          y_slice_size, add_inputs)

    def _apply_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool = True,
    ):
        """
        Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
        computation, which is suitable for the
        GEMM of lora'b.
        """
        # Dispatch on the current stage flag maintained by the base class.
        expand_slice_fun: Callable = (self._expand_slice_prefill
                                      if self.is_prefill else
                                      self._expand_slice_decode)
        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)

    def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
                      w_t_all: torch.Tensor, scale: float):
        """
        Perform the ` y+=x@w_t_all` computation, which is suitable for the
        GEMM of lora'a.
        When `is_prefill is` true, it indicates that it is currently the
        prefill stage, and the `_shrink_prefill` function should be called.
        Otherwise, it is the decode stage, and the _shrink_decode function
        should be called.
        """
        # Flatten to 2-D for the kernels, then restore the caller's view.
        y_org = y
        y = y.view(-1, y.shape[-1])
        shrink_fun: Callable = (self._shrink_prefill
                                if self.is_prefill else self._shrink_decode)
        shrink_fun(y, x, w_t_all, scale)
        y = y.view_as(y_org)

    def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                   x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
                   scale: float, **kwargs):
        """
        Performs GEMM for multiple slices of lora_a.
        When `is_prefill is` true, it indicates that it is currently the
        prefill stage, and the `_shrink_prefill` function should be called.
        Otherwise, it is the decode stage, and the _shrink_decode function
        should be called.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (x @ lora_a_stacked[i]) * scale

        Args:
            y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
            x (torch.Tensor): Input tensor
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
            scale (float): Scaling factor for the operation
        """

        x = x.view(-1, x.shape[-1])
        # TODO fuse these kernels
        for slice_idx in range(len(lora_a_stacked)):
            self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
                               scale)

    def add_expand(self,
                   y: torch.Tensor,
                   x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                   lora_b_stacked: Tuple[torch.Tensor, ...],
                   lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                   output_slices: Tuple[int, ...],
                   offset_start: int = 0,
                   add_inputs=True,
                   **kwargs) -> None:
        """
        Performs GEMM and bias addition for multiple slices of lora_b.

        Semantics:
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
                    lora_bias_stacked[i]
                offset += slice

        Args:
            y (torch.Tensor): Output tensor.
            x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
                bias's weight
            output_slices (Tuple[int, ...]): Every slice's size
            add_inputs (bool): Defaults to True.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        offset_left = offset_start
        if lora_bias_stacked is not None:
            # Bias is applied once for all slices (helper from the base
            # class), then each slice's expand GEMM accumulates on top.
            self._apply_bias(self.token_lora_indices, y, output_slices,
                             lora_bias_stacked)
        for slice_idx in range(len(lora_b_stacked)):
            self._apply_expand(
                y,
                x[slice_idx],
                lora_b_stacked[slice_idx],
                offset_left,
                output_slices[slice_idx],
                add_inputs=add_inputs,
            )
            # Each slice writes into consecutive column ranges of y.
            offset_left += output_slices[slice_idx]
        y = y.view_as(y_org)

    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_inputs: bool = True,
                           **kwargs) -> None:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA.

        Semantics:
            y += x @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            add_inputs (bool): Default to True.
        """

        # Embedding layer only need expand op
        expand_fun: Callable = (self._expand_prefill
                                if self.is_prefill else self._expand_decode)
        expand_fun(y, x, lora_b_stacked, add_inputs)

    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: Tuple[torch.Tensor, ...],
                        lora_b_stacked: Tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: Tuple[int, ...],
                        *,
                        buffer: Optional[Tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> None:
        """
        Applicable to linear-related lora.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (
                    x[i].unsqueeze(0)
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
                    ).squeeze(0)+lora_bias_stacked[i]

        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
            scale (float): Scaling factor.
            output_slices (Tuple[int, ...]): Every slice's size.
            buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
        """

        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
        if lora_bias_stacked is not None:
            assert len(lora_bias_stacked) == len(output_slices)
            y = self._apply_bias(self.token_lora_indices, y, output_slices,
                                 lora_bias_stacked)

        if buffer is None:
            # r is the LoRA rank (last dim of the stacked B weights).
            r = lora_b_stacked[0].size(-1)
            # We set the buffer to be float32 by default, consistent with the
            # triton op
            # NOTE(review): the buffer is sized from x.size(0) before
            # add_shrink flattens x — confirm callers always pass 2-D x here.
            buffer = tuple(
                torch.zeros(
                    (x.size(0), r), dtype=torch.float32, device=x.device)
                for _ in range(len(output_slices)))
        # Two-stage LoRA: shrink into the rank-r buffer, then expand (with
        # accumulation) back into y across the output slices.
        self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
        self.add_expand(y,
                        buffer,
                        lora_b_stacked,
                        None,
                        output_slices,
                        add_inputs=True,
                        **kwargs)

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.

        Semantics:
            buffer = (x @ lora_a_stacked) * scale
            y += buffer @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (torch.Tensor): lora_a's weights.
            lora_b_stacked (torch.Tensor):lora_b's weights.
            scale (float): Scaling factor.
            buffer (Optional[torch.Tensor]):Default to None.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = lora_b_stacked.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default, consistent with the
            # triton op
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)
        # LogitsProcessorWithLoRA always using bgmv.
        # sampler_indices (from the base class) maps each sampled row to its
        # LoRA adapter.
        bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
        bgmv_expand(buffer,
                    lora_b_stacked,
                    y,
                    self.sampler_indices,
                    add_inputs=True)
        y = y.view_as(y_org)
|
||||
@ -12,11 +12,11 @@ import torch
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.lora.ops.bgmv_expand import bgmv_expand
|
||||
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
|
||||
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
|
||||
from vllm.lora.ops.sgmv_expand import sgmv_expand
|
||||
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
|
||||
from vllm.lora.ops.triton_ops import bgmv_expand
|
||||
from vllm.lora.ops.triton_ops import bgmv_expand_slice
|
||||
from vllm.lora.ops.triton_ops import bgmv_shrink
|
||||
from vllm.lora.ops.triton_ops import sgmv_expand
|
||||
from vllm.lora.ops.triton_ops import sgmv_shrink
|
||||
|
||||
from .punica_base import PunicaWrapperBase
|
||||
|
||||
|
||||
@ -12,6 +12,11 @@ def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
|
||||
from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
|
||||
logger.info_once("Using PunicaWrapperGPU.")
|
||||
return PunicaWrapperGPU(*args, **kwargs)
|
||||
elif current_platform.is_cpu():
|
||||
# Lazy import to avoid ImportError
|
||||
from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU
|
||||
logger.info_once("Using PunicaWrapperCPU.")
|
||||
return PunicaWrapperCPU(*args, **kwargs)
|
||||
elif current_platform.is_hpu():
|
||||
# Lazy import to avoid ImportError
|
||||
from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU
|
||||
|
||||
@ -2,8 +2,8 @@ import dataclasses
|
||||
import weakref
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar,
|
||||
Union)
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Type,
|
||||
TypeVar, Union)
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -12,10 +12,14 @@ from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.models import supports_lora, supports_multimodal
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
||||
MultiModalKwargs, MultiModalPlaceholderMap)
|
||||
from vllm.sequence import (IntermediateTensors, SequenceData,
|
||||
@ -49,6 +53,8 @@ class ModelInputForCPU(ModelRunnerInputBase):
|
||||
virtual_engine: Optional[int] = None
|
||||
seq_lens: Optional[List[int]] = None
|
||||
query_lens: Optional[List[int]] = None
|
||||
lora_mapping: Optional["LoRAMapping"] = None
|
||||
lora_requests: Optional[Set[LoRARequest]] = None
|
||||
|
||||
def as_broadcastable_tensor_dict(
|
||||
self) -> Dict[str, Union[int, torch.Tensor]]:
|
||||
@ -57,6 +63,8 @@ class ModelInputForCPU(ModelRunnerInputBase):
|
||||
"input_positions": self.input_positions,
|
||||
"token_type_ids": self.token_type_ids,
|
||||
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||
"lora_requests": self.lora_requests,
|
||||
"lora_mapping": self.lora_mapping,
|
||||
}
|
||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||
|
||||
@ -143,7 +151,11 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
||||
or runner.cache_config.enable_prefix_caching)
|
||||
self.model_input_cls = self.runner._model_input_cls
|
||||
self.attn_backend = self.runner.attn_backend
|
||||
self.sliding_window = self.runner.sliding_window
|
||||
self.block_size = self.runner.block_size
|
||||
self.device = self.runner.device
|
||||
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
|
||||
self.enable_lora = self.runner.lora_config is not None
|
||||
self.input_data = ModelInputForCPUBuilder.ModelInputData(
|
||||
self.runner.model_config.uses_mrope)
|
||||
self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
|
||||
@ -183,15 +195,28 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
||||
attn_metadata = self.att_metadata_builder.build(
|
||||
input_data.seq_lens, input_data.query_lens, -1, -1)
|
||||
|
||||
return self.model_input_cls(
|
||||
input_tokens=input_tokens,
|
||||
input_positions=input_positions,
|
||||
token_type_ids=token_type_ids,
|
||||
seq_lens=input_data.seq_lens,
|
||||
query_lens=input_data.query_lens,
|
||||
attn_metadata=attn_metadata,
|
||||
multi_modal_kwargs=multi_modal_kwargs,
|
||||
)
|
||||
is_prompt = (self.seq_group_metadata_list[0].is_prompt
|
||||
if self.seq_group_metadata_list else None)
|
||||
# LoRA data.
|
||||
lora_requests = set()
|
||||
lora_mapping = None
|
||||
if self.enable_lora:
|
||||
lora_requests = set(seq.lora_request
|
||||
for seq in self.seq_group_metadata_list
|
||||
if seq.lora_request is not None)
|
||||
|
||||
lora_mapping = self._prepare_lora_input(
|
||||
self.seq_group_metadata_list, is_prompt)
|
||||
|
||||
return self.model_input_cls(input_tokens=input_tokens,
|
||||
input_positions=input_positions,
|
||||
token_type_ids=token_type_ids,
|
||||
seq_lens=input_data.seq_lens,
|
||||
query_lens=input_data.query_lens,
|
||||
attn_metadata=attn_metadata,
|
||||
multi_modal_kwargs=multi_modal_kwargs,
|
||||
lora_mapping=lora_mapping,
|
||||
lora_requests=lora_requests)
|
||||
|
||||
def _build_input_data(self):
|
||||
for seq_group_metadata in self.seq_group_metadata_list:
|
||||
@ -381,6 +406,24 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
||||
self.input_data.multi_modal_placeholder_maps[modality].extend(
|
||||
placeholder_map)
|
||||
|
||||
def _prepare_lora_input(
|
||||
self, seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
is_prefill: bool) -> LoRAMapping:
|
||||
index_mapping = []
|
||||
prompt_mapping = []
|
||||
for seq in seq_group_metadata_list:
|
||||
lora_id = seq.lora_int_id
|
||||
query_len = seq.token_chunk_size
|
||||
|
||||
index_mapping += [lora_id] * query_len
|
||||
prompt_mapping += [lora_id] * (
|
||||
query_len if seq.sampling_params
|
||||
and seq.sampling_params.prompt_logprobs is not None else 1)
|
||||
|
||||
return LoRAMapping(index_mapping=tuple(index_mapping),
|
||||
prompt_mapping=tuple(prompt_mapping),
|
||||
is_prefill=is_prefill)
|
||||
|
||||
|
||||
class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
|
||||
"""
|
||||
@ -431,10 +474,41 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
|
||||
|
||||
# Lazy initialization.
|
||||
self.model: nn.Module # Set after init_Model
|
||||
# Set after load_model.
|
||||
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
|
||||
|
||||
    def load_model(self) -> None:
        """Load the model and, when LoRA is configured, wrap it with an
        LRU-cache-based worker LoRA manager."""
        self.model = get_model(vllm_config=self.vllm_config)

        if self.lora_config:
            # LoRA support is a per-model capability; fail loudly otherwise.
            assert supports_lora(
                self.model
            ), f"{self.model.__class__.__name__} does not support LoRA yet."

            if supports_multimodal(self.model):
                logger.warning("Regarding multimodal models, vLLM currently "
                               "only supports adding LoRA to language model.")

            # It's necessary to distinguish between the max_position_embeddings
            # of VLMs and LLMs.
            if hasattr(self.model.config, "max_position_embeddings"):
                max_pos_embeddings = self.model.config.max_position_embeddings
            else:
                # VLMs keep it on the nested text_config.
                max_pos_embeddings = (
                    self.model.config.text_config.max_position_embeddings)

            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
                self.vocab_size,
                self.lora_config,
                self.device,
                self.model.embedding_modules,
                self.model.embedding_padding_modules,
                max_position_embeddings=max_pos_embeddings,
            )
            # The manager returns the (possibly wrapped) model to use.
            self.model = self.lora_manager.create_lora_manager(self.model)
|
||||
|
||||
def _prepare_model_input_tensors(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
@ -459,6 +533,37 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
|
||||
def vocab_size(self) -> int:
|
||||
return self.model_config.get_vocab_size()
|
||||
|
||||
    def remove_all_loras(self):
        """Drop every LoRA adapter registered with the manager."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.remove_all_adapters()
|
||||
|
||||
    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Activate the given adapters with the per-token/per-sample mapping
        for the current batch."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
|
||||
|
||||
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register a LoRA adapter; returns the manager's success flag."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.add_adapter(lora_request)
|
||||
|
||||
    def remove_lora(self, lora_id: int) -> bool:
        """Unregister the adapter with the given id; returns success."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.remove_adapter(lora_id)
|
||||
|
||||
    def pin_lora(self, lora_id: int) -> bool:
        """Pin an adapter so the LRU cache never evicts it; returns success."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.pin_adapter(lora_id)
|
||||
|
||||
    def list_loras(self) -> Set[int]:
        """Return the ids of all currently registered adapters."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.list_adapters()
|
||||
|
||||
|
||||
class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
|
||||
_model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
|
||||
@ -515,6 +620,12 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
|
||||
raise ValueError(
|
||||
"CPU worker does not support multi-step execution.")
|
||||
|
||||
if self.lora_config:
|
||||
assert model_input.lora_requests is not None
|
||||
assert model_input.lora_mapping is not None
|
||||
self.set_active_loras(model_input.lora_requests,
|
||||
model_input.lora_mapping)
|
||||
|
||||
model_executable = self.model
|
||||
|
||||
multimodal_kwargs = {}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""A CPU worker class."""
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
from typing import Dict, List, Optional, Set, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
@ -11,14 +11,14 @@ from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache
|
||||
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
|
||||
from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
|
||||
from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
|
||||
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
|
||||
LoraNotSupportedWorkerBase, WorkerBase,
|
||||
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
|
||||
WorkerInput)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -111,7 +111,7 @@ class CPUCacheEngine:
|
||||
return dtype_size * total
|
||||
|
||||
|
||||
class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
||||
class CPUWorker(LocalOrDistributedWorkerBase):
|
||||
"""A worker class that executes (a partition of) the model on a CPU socket.
|
||||
|
||||
Each worker is associated with a single CPU socket. The worker is
|
||||
@ -266,6 +266,18 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
||||
# Initialize the cache.
|
||||
self._init_cache_engine()
|
||||
|
||||
    def add_lora(self, lora_request: LoRARequest) -> bool:
        # Delegate to the model runner, which owns the LoRA manager.
        return self.model_runner.add_lora(lora_request)
|
||||
|
||||
    def remove_lora(self, lora_id: int) -> bool:
        # Delegate to the model runner, which owns the LoRA manager.
        return self.model_runner.remove_lora(lora_id)
|
||||
|
||||
    def pin_lora(self, lora_id: int) -> bool:
        # Delegate to the model runner, which owns the LoRA manager.
        return self.model_runner.pin_lora(lora_id)
|
||||
|
||||
    def list_loras(self) -> Set[int]:
        # Delegate to the model runner, which owns the LoRA manager.
        return self.model_runner.list_loras()
|
||||
|
||||
def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
|
||||
"""Raise errors if the num_cpu_blocks is invalid.
|
||||
"""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user