[LoRA][2/2] Remove LoRA extra vocab (#28545)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2025-11-21 09:46:43 +08:00 committed by GitHub
parent df44df0143
commit 9875be6431
28 changed files with 133 additions and 528 deletions
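
This change removes the LoRA extra-vocabulary path end to end: the lora_extra_vocab_size config field and --lora-extra-vocab-size flag, the per-adapter embeddings_tensor plumbing in the LoRA layers, and the extra_vocab_size argument threaded through the Punica wrappers. As an orientation aid only (this helper is not part of the diff), the new update_metadata call shape looks like:

from vllm.lora.models import LoRAMapping


def refresh_lora_metadata(
    punica_wrapper,  # a PunicaWrapper* instance from vllm.lora.punica_wrapper
    mapping: LoRAMapping,
    lora_index_to_id: list[int | None],
    max_loras: int,
    vocab_size: int,
) -> None:
    # Pre-#28545 callers also passed lora_config.lora_extra_vocab_size as a fifth
    # argument; PunicaWrapperBase now pins extra_vocab_size to 0 internally.
    punica_wrapper.update_metadata(mapping, lora_index_to_id, max_loras, vocab_size)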

View File

@ -250,6 +250,16 @@ def olmoe_lora_files():
return snapshot_download(repo_id="jeeejeee/olmoe-instruct-text2sql-spider")
@pytest.fixture(scope="session")
def qwen3_lora_files():
return snapshot_download(repo_id="charent/self_cognition_Alice")
@pytest.fixture(scope="session")
def llama32_lora_files():
return snapshot_download(repo_id="jeeejeee/llama32-3b-text2sql-spider")
@pytest.fixture
def reset_default_device():
"""

View File

@ -136,7 +136,6 @@ def populate_loras(
id_to_index: list[int | None],
layer: BaseLayerWithLoRA,
layer_weights: torch.Tensor,
generate_embeddings_tensor: int = 0,
repeats: int = 1,
) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
"""This method populates the lora layers with lora weights.
@ -148,8 +147,6 @@ def populate_loras(
layer: the LoRAlayer to populate.
layer_weights: the PyTorch tensor containing the layer's
weights.
generate_embeddings_tensor: whether to generate an
embeddings tensor for each LoRA.
repeats: must only be set for column parallel packed
layers. Indicates the number of loras to compose
together to create a single lora layer.
@ -171,7 +168,6 @@ def populate_loras(
sublora = DummyLoRAManager(layer_weights.device).init_random_lora(
module_name=f"fake_{i}",
weight=layer_weights,
generate_embeddings_tensor=generate_embeddings_tensor,
)
sublora.lora_b = sublora.lora_b[
(sublora_len * i) : (sublora_len * (i + 1)), :
@ -185,7 +181,6 @@ def populate_loras(
slot_idx,
lora_a=lora.lora_a,
lora_b=lora.lora_b,
embeddings_tensor=lora.embeddings_tensor,
)
lora_dict[lora_id] = lora
@ -306,7 +301,6 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
id_to_index,
max_loras,
vocab_size,
lora_config.lora_extra_vocab_size,
)
lora_result = lora_embedding(torch.cat(inputs))
@ -344,7 +338,6 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
id_to_index,
max_loras,
vocab_size,
lora_config.lora_extra_vocab_size,
)
lora_result = lora_embedding(torch.cat(inputs))
@ -354,149 +347,6 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
@torch.inference_mode()
# @pytest.mark.skip(
# reason="Fails when loras are in any slot other than the first.")
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings_with_new_embeddings(
dist_init, num_loras, device, vocab_size, stage
) -> None:
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)
torch.set_default_device(device)
max_loras = 8
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(
max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
)
def create_random_embedding_layer():
embedding = VocabParallelEmbedding(vocab_size, 256)
embedding_data = torch.rand_like(embedding.weight.data)
embedding.weight.data = embedding_data
embedding.weight.data[vocab_size:, :] = 0
expanded_embedding = VocabParallelEmbedding(
vocab_size + lora_config.lora_extra_vocab_size * max_loras,
256,
org_num_embeddings=vocab_size,
)
expanded_embedding.weight.data[:vocab_size, :] = embedding_data
# We need to deepcopy the embedding as it will be modified
# in place
lora_embedding = VocabParallelEmbeddingWithLoRA(deepcopy(expanded_embedding))
lora_embedding.create_lora_weights(max_loras, lora_config)
return expanded_embedding, lora_embedding
for i in range(NUM_RANDOM_SEEDS):
set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras)
expanded_embedding, lora_embedding = create_random_embedding_layer()
lora_dict, _ = populate_loras(
id_to_index,
layer=lora_embedding,
layer_weights=torch.zeros(
(256, vocab_size + lora_config.lora_extra_vocab_size)
),
generate_embeddings_tensor=256,
)
lora_embedding.set_mapping(punica_wrapper)
# All embeddings tensors have the same shape.
embeddings_tensors = [
lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys())
]
embeddings_tensor_len = embeddings_tensors[0].shape[0]
# Add empty embeddings_tensors for unoccupied lora slots.
for _ in range(max_loras - len(embeddings_tensors)):
embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape))
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=list(lora_dict.keys()),
num_inputs=num_loras * 3,
input_size=(200,),
input_range=(1, vocab_size),
device=device,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
punica_wrapper.update_metadata(
lora_mapping,
id_to_index,
max_loras,
vocab_size,
lora_config.lora_extra_vocab_size,
)
original_inputs = deepcopy(inputs)
# Force some of the inputs to be in the extended embeddings range
# to guarantee that their behavior is tested.
for input_, original_input_, lora_id in zip(
inputs, original_inputs, prompt_mapping
):
embedding_id = lora_id - 1
input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
original_input_[-1] = vocab_size
input_[-2] = vocab_size + ((embedding_id + 1) * embeddings_tensor_len - 1)
original_input_[-2] = vocab_size + embeddings_tensor_len - 1
expanded_embedding.weight[
vocab_size : vocab_size + (embeddings_tensor_len * max_loras)
] = torch.cat(embeddings_tensors)
lora_result = lora_embedding(torch.cat(original_inputs))
expected_results: list[torch.Tensor] = []
for input_, original_input_, lora_id in zip(
inputs, original_inputs, prompt_mapping
):
lora = lora_dict[lora_id]
result = expanded_embedding(input_)
after_a = F.embedding(
original_input_,
lora.lora_a.T,
)
result += after_a @ lora.lora_b.T
expected_results.append(result)
expected_result = torch.cat(expected_results)
rtol, atol = TOLERANCES[lora_result.dtype]
torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
# Check that resetting the lora weights succeeds
for slot_idx in range(max_loras):
lora_embedding.reset_lora(slot_idx)
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=[0],
num_inputs=num_loras * 3,
input_size=(200,),
input_range=(1, vocab_size),
device=device,
)
original_inputs = deepcopy(inputs)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
punica_wrapper.update_metadata(
lora_mapping,
id_to_index,
max_loras,
vocab_size,
lora_config.lora_extra_vocab_size,
)
lora_result = lora_embedding(torch.cat(original_inputs))
expected_result = expanded_embedding(torch.cat(inputs))
rtol, atol = TOLERANCES[lora_result.dtype]
torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@ -518,16 +368,13 @@ def test_lm_head_logits_processor(
def _pretest():
linear = ParallelLMHead(
vocab_size + lora_config.lora_extra_vocab_size,
1024,
vocab_size,
num_embeddings=vocab_size,
embedding_dim=1024,
params_dtype=torch.float16,
)
linear.weight.data = torch.rand_like(linear.weight.data)
linear.weight.data[:, vocab_size:] = 0
logits_processor = LogitsProcessor(
vocab_size + lora_config.lora_extra_vocab_size, vocab_size
)
logits_processor = LogitsProcessor(vocab_size)
lora_logits_processor = LogitsProcessorWithLoRA(
logits_processor, 1024, linear.weight.dtype, linear.weight.device, None
)
@ -541,15 +388,12 @@ def test_lm_head_logits_processor(
id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, logits_processor, lora_logits_processor = _pretest()
lora_logits_processor.set_mapping(punica_wrapper)
# NOTE: all the generated loras share the same embeddings tensor.
lora_dict, _ = populate_loras(
id_to_index,
layer=lora_logits_processor,
layer_weights=linear.weight,
generate_embeddings_tensor=1024,
)
embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor
embeddings_tensor_len = embeddings_tensor.shape[0]
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=list(lora_dict.keys()),
@ -565,7 +409,6 @@ def test_lm_head_logits_processor(
id_to_index,
max_loras,
vocab_size,
lora_config.lora_extra_vocab_size,
)
input_ = torch.rand(20, 1024)
@ -575,23 +418,16 @@ def test_lm_head_logits_processor(
original_lm_head = deepcopy(linear)
linear.weight[
logits_processor.org_vocab_size : logits_processor.org_vocab_size
+ embeddings_tensor_len
] = embeddings_tensor
logits_processor.org_vocab_size = vocab_size + lora_config.lora_extra_vocab_size
expected_results: list[torch.Tensor] = []
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = logits_processor._get_logits(
hidden_states=input_, lm_head=linear, embedding_bias=None
)
result[:, vocab_size + embeddings_tensor_len :] = float("-inf")
result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)
logits_processor.org_vocab_size = vocab_size
# Check that resetting the lora weights succeeds
@ -612,7 +448,6 @@ def test_lm_head_logits_processor(
id_to_index,
max_loras,
vocab_size,
lora_config.lora_extra_vocab_size,
)
lora_result = lora_logits_processor._get_logits(
@ -694,7 +529,6 @@ def test_linear_replicated(
id_to_index,
max_loras,
512,
lora_config.lora_extra_vocab_size,
)
lora_result = lora_linear(torch.cat(inputs))[0]
@ -726,7 +560,10 @@ def test_linear_replicated(
lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
punica_wrapper.update_metadata(
lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size
lora_mapping,
id_to_index,
max_loras,
512,
)
lora_result = lora_linear(torch.cat(inputs))[0]
@ -817,7 +654,6 @@ def test_linear_parallel(
id_to_index,
max_loras,
512,
lora_config.lora_extra_vocab_size,
)
lora_result = lora_linear(torch.cat(inputs))[0]
@ -849,7 +685,10 @@ def test_linear_parallel(
lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
punica_wrapper.update_metadata(
lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size
lora_mapping,
id_to_index,
max_loras,
512,
)
lora_result = lora_linear(torch.cat(inputs))[0]
@ -963,7 +802,6 @@ def test_column_parallel_packed(
id_to_index,
max_loras,
512,
lora_config.lora_extra_vocab_size,
)
lora_result = lora_linear(torch.cat(inputs))[0]
@ -1000,7 +838,6 @@ def test_column_parallel_packed(
id_to_index,
max_loras,
512,
lora_config.lora_extra_vocab_size,
)
lora_result = lora_linear(torch.cat(inputs))[0]
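
With the field gone, the tests above construct LoRAConfig from the remaining knobs only. A minimal sketch, assuming LoRAConfig is importable from vllm.config as in these tests; lora_extra_vocab_size is no longer an accepted argument:

import torch

from vllm.config import LoRAConfig

# Mirrors the construction used in tests/lora/test_layers.py after this change.
lora_config = LoRAConfig(
    max_loras=8,
    max_lora_rank=8,
    lora_dtype=torch.float16,
)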

View File

@ -13,17 +13,27 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from ..utils import VLLM_PATH, create_new_process_for_each_test, multi_gpu_test
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
PROMPT_TEMPLATE = """<|eot_id|><|start_header_id|>user<|end_header_id|>
I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
"
##Instruction:
candidate_poll contains tables such as candidate, people. Table candidate has columns such as Candidate_ID, People_ID, Poll_Source, Date, Support_rate, Consider_rate, Oppose_rate, Unsure_rate. Candidate_ID is the primary key.
Table people has columns such as People_ID, Sex, Name, Date_of_Birth, Height, Weight. People_ID is the primary key.
The People_ID of candidate is the foreign key of People_ID of people.
###Input:
{context}
###Response:<|eot_id|><|start_header_id|>assistant<|end_header_id|>
""" # noqa: E501
EXPECTED_LORA_OUTPUT = [
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
" SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ",
" SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501
" SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501
" SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ",
" SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' ", # noqa: E501
"SELECT count(*) FROM candidate",
"SELECT count(*) FROM candidate",
"SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
"SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
]
MODEL_PATH = "meta-llama/Llama-3.2-3B-Instruct"
def do_sample(
llm: vllm.LLM,
@ -32,18 +42,19 @@ def do_sample(
tensorizer_config_dict: dict | None = None,
) -> list[str]:
prompts = [
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]", # noqa: E501
PROMPT_TEMPLATE.format(context="How many candidates are there?"),
PROMPT_TEMPLATE.format(context="Count the number of candidates."),
PROMPT_TEMPLATE.format(
context="Which poll resource provided the most number of candidate information?" # noqa: E501
),
PROMPT_TEMPLATE.format(
context="Return the poll resource associated with the most candidates."
),
]
sampling_params = vllm.SamplingParams(
temperature=0, max_tokens=256, skip_special_tokens=False, stop=["[/assistant]"]
temperature=0, max_tokens=64, stop=["<|im_end|>"]
)
if tensorizer_config_dict is not None:
outputs = llm.generate(
prompts,
@ -75,13 +86,15 @@ def do_sample(
return generated_texts
def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None = None):
def generate_and_test(
llm, llama32_lora_files, tensorizer_config_dict: dict | None = None
):
print("lora adapter created")
print("lora 1")
assert (
do_sample(
llm,
sql_lora_files,
llama32_lora_files,
tensorizer_config_dict=tensorizer_config_dict,
lora_id=1,
)
@ -92,7 +105,7 @@ def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None =
assert (
do_sample(
llm,
sql_lora_files,
llama32_lora_files,
tensorizer_config_dict=tensorizer_config_dict,
lora_id=2,
)
@ -104,51 +117,52 @@ def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None =
@create_new_process_for_each_test()
@pytest.mark.parametrize("cudagraph_specialize_lora", [True, False])
def test_llama_lora(sql_lora_files, cudagraph_specialize_lora: bool):
def test_llama_lora(llama32_lora_files, cudagraph_specialize_lora: bool):
llm = vllm.LLM(
MODEL_PATH,
tokenizer=sql_lora_files,
enable_lora=True,
# also test odd max_num_seqs
max_num_seqs=13,
max_num_seqs=7,
max_model_len=1024,
max_loras=4,
compilation_config=vllm.config.CompilationConfig(
cudagraph_specialize_lora=cudagraph_specialize_lora,
),
)
generate_and_test(llm, sql_lora_files)
generate_and_test(llm, llama32_lora_files)
@multi_gpu_test(num_gpus=4)
def test_llama_lora_tp4(sql_lora_files):
def test_llama_lora_tp4(llama32_lora_files):
llm = vllm.LLM(
MODEL_PATH,
tokenizer=sql_lora_files,
enable_lora=True,
max_num_seqs=16,
max_num_seqs=7,
max_model_len=1024,
max_loras=4,
tensor_parallel_size=4,
)
generate_and_test(llm, sql_lora_files)
generate_and_test(llm, llama32_lora_files)
@multi_gpu_test(num_gpus=4)
def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
def test_llama_lora_tp4_fully_sharded_loras(llama32_lora_files):
llm = vllm.LLM(
MODEL_PATH,
tokenizer=sql_lora_files,
enable_lora=True,
max_num_seqs=16,
max_num_seqs=8,
max_loras=4,
max_model_len=1024,
tensor_parallel_size=4,
fully_sharded_loras=True,
)
generate_and_test(llm, sql_lora_files)
generate_and_test(llm, llama32_lora_files)
@multi_gpu_test(num_gpus=2)
def test_tp2_serialize_and_deserialize_lora(
tmp_path, sql_lora_files, sql_lora_huggingface_id
tmp_path,
llama32_lora_files,
):
# Run the tensorizing of the LoRA adapter and the model in a subprocess
# to guarantee cleanup
@ -157,7 +171,7 @@ def test_tp2_serialize_and_deserialize_lora(
model_name = "model-rank-%03d.tensors"
model_ref = MODEL_PATH
lora_path = sql_lora_huggingface_id
lora_path = llama32_lora_files
suffix = "test"
try:
result = subprocess.run(
@ -195,12 +209,12 @@ def test_tp2_serialize_and_deserialize_lora(
loaded_llm = LLM(
model=model_ref,
tokenizer=sql_lora_files,
load_format="tensorizer",
enable_lora=True,
enforce_eager=True,
model_loader_extra_config=tensorizer_config,
max_num_seqs=13,
max_num_seqs=7,
max_model_len=1024,
tensor_parallel_size=2,
max_loras=2,
)
@ -211,7 +225,7 @@ def test_tp2_serialize_and_deserialize_lora(
print("lora 1")
assert (
do_sample(
loaded_llm, sql_lora_files, tensorizer_config_dict=tc_as_dict, lora_id=1
loaded_llm, llama32_lora_files, tensorizer_config_dict=tc_as_dict, lora_id=1
)
== EXPECTED_LORA_OUTPUT
)
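
Putting the reworked test together, the adapter is exercised through vLLM's public generate/LoRARequest path. A condensed sketch of that flow (the variable names are mine; PROMPT_TEMPLATE is the template defined at the top of this test file, and the adapter repo id comes from the llama32_lora_files fixture):

import vllm
from huggingface_hub import snapshot_download
from vllm.lora.request import LoRARequest

lora_path = snapshot_download(repo_id="jeeejeee/llama32-3b-text2sql-spider")
llm = vllm.LLM(
    "meta-llama/Llama-3.2-3B-Instruct",
    enable_lora=True,
    max_num_seqs=7,
    max_model_len=1024,
    max_loras=4,
)
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64, stop=["<|im_end|>"])
prompts = [PROMPT_TEMPLATE.format(context="How many candidates are there?")]
outputs = llm.generate(
    prompts,
    sampling_params,
    lora_request=LoRARequest("sql_adapter", 1, lora_path),
)
print(outputs[0].outputs[0].text)  # expected: "SELECT count(*) FROM candidate"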

View File

@ -13,8 +13,8 @@ from vllm.entrypoints.openai.api_server import (
from vllm.lora.request import LoRARequest
from vllm.v1.engine.llm_engine import LLMEngine
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
LORA_MODULE_PATH = "yard1/llama-2-7b-sql-lora-test"
MODEL_PATH = "Qwen/Qwen3-0.6B"
LORA_MODULE_PATH = "charent/self_cognition_Alice"
LORA_RANK = 8

View File

@ -48,9 +48,6 @@ DEFAULT_DTYPE = torch.get_default_dtype()
@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
tensors = load_file(os.path.join(sql_lora_files, "adapter_model.safetensors"))
new_embeddings = load_file(
os.path.join(sql_lora_files, "new_embeddings.safetensors")
)
peft_helper = PEFTHelper.from_local_dir(
sql_lora_files, max_position_embeddings=4096
@ -60,7 +57,6 @@ def test_from_lora_tensors(sql_lora_files, device):
tensors,
peft_helper=peft_helper,
device=device,
embeddings=new_embeddings,
embedding_modules=EMBEDDING_MODULES,
embedding_padding_modules=EMBEDDING_PADDING_MODULES,
)
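
The trimmed loading path no longer reads new_embeddings.safetensors, and from_lora_tensors has lost its embeddings= parameter. A stand-alone sketch of the call exercised by this test (the helper and its defaults are mine; the test passes its EMBEDDING_MODULES / EMBEDDING_PADDING_MODULES constants, and PEFTHelper is assumed to live in vllm.lora.peft_helper):

import os

from safetensors.torch import load_file

from vllm.lora.models import LoRAModel
from vllm.lora.peft_helper import PEFTHelper


def load_adapter(
    lora_dir: str,
    device: str = "cpu",
    embedding_modules: dict[str, str] | None = None,
    embedding_padding_modules: list[str] | None = None,
) -> LoRAModel:
    tensors = load_file(os.path.join(lora_dir, "adapter_model.safetensors"))
    peft_helper = PEFTHelper.from_local_dir(lora_dir, max_position_embeddings=4096)
    # Base embed_tokens / lm_head weights saved alongside the adapter are now
    # skipped inside from_lora_tensors via is_base_embeddding_weights.
    return LoRAModel.from_lora_tensors(
        lora_model_id=1,
        tensors=tensors,
        peft_helper=peft_helper,
        device=device,
        embedding_modules=embedding_modules,
        embedding_padding_modules=embedding_padding_modules,
    )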
@ -76,18 +72,6 @@ def test_from_lora_tensors(sql_lora_files, device):
f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
)
assert lora.lora_a.shape[0] == 8
embeddings_module = next(
(k for k in EMBEDDING_MODULES if k in module_name), None
)
if embeddings_module:
assert torch.equal(
lora.embeddings_tensor,
new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
device=lora.embeddings_tensor.device
),
)
else:
assert lora.embeddings_tensor is None
def create_lora(
@ -552,9 +536,7 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path
worker_adapter_manager = WorkerLoRAManager(
vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES
)
worker_adapter_manager.vocab_size = (
dummy_model_gate_up.unpadded_vocab_size - lora_config.lora_extra_vocab_size
)
worker_adapter_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
worker_adapter_manager.create_lora_manager(dummy_model_gate_up)
dummy_lora_files = f"{tmp_path}/lora_adapter"

View File

@ -20,11 +20,12 @@ from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.v1.worker.gpu_worker import Worker
MODEL_PATH = "Qwen/Qwen3-0.6B"
NUM_LORAS = 16
@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files):
def test_worker_apply_lora(qwen3_lora_files):
def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]):
lora_mapping = LoRAMapping([], [])
@ -34,9 +35,10 @@ def test_worker_apply_lora(sql_lora_files):
vllm_config = VllmConfig(
model_config=ModelConfig(
"meta-llama/Llama-2-7b-hf",
MODEL_PATH,
seed=0,
dtype="float16",
max_model_len=127,
enforce_eager=True,
),
load_config=LoadConfig(
@ -73,7 +75,7 @@ def test_worker_apply_lora(sql_lora_files):
assert worker.list_loras() == set()
lora_requests = [
LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(NUM_LORAS)
LoRARequest(str(i + 1), i + 1, qwen3_lora_files) for i in range(NUM_LORAS)
]
set_active_loras(worker, lora_requests)

View File

@ -28,7 +28,6 @@ class DummyLoRAManager:
module_name: str,
weight: torch.Tensor,
rank: int = 8,
generate_embeddings_tensor: int = 0,
):
lora = LoRALayerWeights(
module_name,
@ -41,13 +40,6 @@ class DummyLoRAManager:
[weight.shape[0], rank], dtype=weight.dtype, device=self._device
),
)
if generate_embeddings_tensor:
lora.embeddings_tensor = torch.rand(
5,
generate_embeddings_tensor,
dtype=weight.dtype,
device=self._device,
)
self.set_module_lora(module_name, lora)
return lora

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from typing import TYPE_CHECKING, Any, ClassVar, Literal
from typing import TYPE_CHECKING, Any, Literal
import torch
from pydantic import ConfigDict, Field, model_validator
@ -11,7 +11,6 @@ from typing_extensions import Self
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.platforms import current_platform
if TYPE_CHECKING:
from vllm.config import ModelConfig
@ -46,19 +45,6 @@ class LoRAConfig:
`max_loras`."""
lora_dtype: torch.dtype | LoRADType = "auto"
"""Data type for LoRA. If auto, will default to base model dtype."""
lora_extra_vocab_size: LoRAExtraVocabSize = Field(
default=256,
deprecated=(
"`lora_extra_vocab_size` is deprecated and will be removed "
"in v0.12.0. Additional vocabulary support for "
"LoRA adapters is being phased out."
),
)
"""(Deprecated) Maximum size of extra vocabulary that can be present in a
LoRA adapter. Will be removed in v0.12.0."""
lora_vocab_padding_size: ClassVar[int] = (
current_platform.get_lora_vocab_padding_size()
)
default_mm_loras: dict[str, str] | None = None
"""Dictionary mapping specific modalities to LoRA model paths; this field
is only applicable to multimodal models and should be leveraged when a
@ -87,8 +73,6 @@ class LoRAConfig:
factors.append(self.max_loras)
factors.append(self.fully_sharded_loras)
factors.append(self.lora_dtype)
factors.append(self.lora_extra_vocab_size)
factors.append(self.lora_vocab_padding_size)
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str

View File

@ -484,7 +484,6 @@ class EngineArgs:
fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override
@ -1011,9 +1010,6 @@ class EngineArgs:
)
lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
lora_group.add_argument(
"--lora-extra-vocab-size", **lora_kwargs["lora_extra_vocab_size"]
)
lora_group.add_argument(
"--lora-dtype",
**lora_kwargs["lora_dtype"],
@ -1680,7 +1676,6 @@ class EngineArgs:
max_loras=self.max_loras,
default_mm_loras=self.default_mm_loras,
fully_sharded_loras=self.fully_sharded_loras,
lora_extra_vocab_size=self.lora_extra_vocab_size,
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras
if self.max_cpu_loras and self.max_cpu_loras > 0

View File

@ -44,7 +44,6 @@ class BaseLayerWithLoRA(nn.Module):
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: torch.Tensor | None,
):
"""Overwrites lora tensors at index."""
...

View File

@ -96,7 +96,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: torch.Tensor | None,
):
# Except for QKVParallelLinearWithLoRA and
# MergedColumnParallelLinearWithLoRA, all other linear LoRA layers

View File

@ -248,7 +248,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: torch.Tensor | None,
):
self.reset_lora(index)

View File

@ -406,8 +406,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: torch.Tensor | None,
bias: torch.Tensor | None = None,
):
"""Overwrites lora tensors at index."""
self.reset_lora(index)

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import torch
import torch.nn as nn
@ -108,22 +107,13 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
(
max_loras,
1,
# Pad for kernel compatibility
math.ceil(
self.base_layer.vocab_size / lora_config.lora_vocab_padding_size
)
* lora_config.lora_vocab_padding_size,
self.base_layer.vocab_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.embeddings_tensors = torch.full(
(max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
fill_value=float("-inf"),
dtype=self.dtype,
device=self.device,
)
if self.sharded_to_full_mapping is not None:
self.sharded_to_full_mapping_gpu = torch.tensor(
self.sharded_to_full_mapping, device=self.device, dtype=torch.long
@ -134,14 +124,12 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
self.embeddings_tensors[index] = float("-inf")
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: torch.Tensor | None,
):
self.reset_lora(index)
self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
@ -150,12 +138,6 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
lora_b, non_blocking=True
)
if embeddings_tensor is not None:
self.embeddings_tensors[
index,
: embeddings_tensor.shape[0],
: embeddings_tensor.shape[1],
] = embeddings_tensor
def _get_logits(
self,
@ -193,39 +175,6 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
# token_id: [0, 1, 2, 3, 4, 5, -1, -1]
logits = logits[:, self.sharded_to_full_mapping_gpu]
lora_logits = torch.empty(
self.embeddings_tensors.shape[0] + 1,
self.embeddings_tensors.shape[1],
hidden_states.shape[0],
dtype=self.embeddings_tensors.dtype,
device=self.embeddings_tensors.device,
)
torch.matmul(self.embeddings_tensors, hidden_states.T, out=lora_logits[:-1])
neg_inf, pos_inf = current_platform.get_infinity_values(lora_logits.dtype)
lora_logits[-1] = neg_inf
lora_logits = lora_logits.mT
indices_padded = self.punica_wrapper.sampler_indices_padded
if current_platform.is_tpu() or current_platform.is_xpu():
indices_padded = indices_padded[: logits.size(0)]
lora_logits = (
lora_logits.reshape(
lora_logits.shape[0] * lora_logits.shape[1],
lora_logits.shape[2],
)
.index_select(0, indices_padded)
.nan_to_num_(nan=neg_inf, posinf=pos_inf, neginf=neg_inf)
)
logits[
:,
self.base_layer.org_vocab_size : self.base_layer.org_vocab_size
+ lora_logits.shape[1],
] = lora_logits
lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_logits(
logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, 1.0
)

View File

@ -46,19 +46,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.embeddings_slice = None
self.embeddings_weights = None
self.embeddings_tensors = torch.zeros(
(
max_loras,
lora_config.lora_extra_vocab_size,
self.base_layer.embedding_dim,
),
dtype=self.base_layer.weight.dtype,
device=self.base_layer.weight.device,
)
self.lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.org_vocab_size + lora_config.lora_extra_vocab_size,
self.base_layer.org_vocab_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
@ -82,14 +73,12 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
self.embeddings_tensors[index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: torch.Tensor | None,
):
self.reset_lora(index)
# NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
@ -100,36 +89,18 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
lora_b, non_blocking=True
)
if embeddings_tensor is not None:
self.embeddings_tensors[
index,
: embeddings_tensor.shape[0],
: embeddings_tensor.shape[1],
].copy_(embeddings_tensor, non_blocking=True)
if self.embeddings_slice is not None:
# TODO(yard1): Optimize this copy, we don't need to copy
# everything, just the modified part
embeddings = self.embeddings_tensors.view(
self.embeddings_tensors.shape[0] * self.embeddings_tensors.shape[1],
self.embeddings_tensors.shape[2],
)[self.embeddings_slice[0] : self.embeddings_slice[1]]
assert self.embeddings_weights is not None
self.embeddings_weights[: embeddings.shape[0]].copy_(embeddings)
def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, 1, 0)
# NB: Don't use torch.narrow here. torch.narrow triggers some
# Dynamic Shape specialization in torch.compile
num_tokens = x.shape[0]
indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens]
indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens]
full_lora_a_embeddings = F.embedding(
x + indices_1,
self.lora_a_stacked_2d,
)
full_output = self.base_layer.forward(x + (indices_0 * added_tokens_mask))
full_output = self.base_layer.forward(x)
full_output_org = full_output
if full_output.ndim == 3:

View File

@ -21,7 +21,6 @@ class LoRALayerWeights:
lora_alpha: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: torch.Tensor | None = None,
scaling: float | None = None,
) -> None:
self.module_name = module_name
@ -29,7 +28,6 @@ class LoRALayerWeights:
self.lora_alpha = lora_alpha
self.lora_a = lora_a
self.lora_b = lora_b
self.embeddings_tensor = embeddings_tensor
if scaling is None:
self.scaling = self.lora_alpha / self.rank
@ -56,18 +54,11 @@ class LoRALayerWeights:
def is_packed(self) -> bool:
return False
@property
def extra_vocab_size(self) -> int:
return (
self.embeddings_tensor.shape[0] if self.embeddings_tensor is not None else 0
)
@classmethod
def from_config(
cls,
module_name: str,
peft_helper: PEFTHelper,
embeddings_tensor: torch.Tensor | None = None,
) -> "LoRALayerWeights":
# lora_a and lora_b are set to None for config-based construction
return cls(
@ -76,7 +67,6 @@ class LoRALayerWeights:
peft_helper.lora_alpha,
None,
None,
embeddings_tensor,
peft_helper.vllm_lora_scaling_factor,
)
@ -89,7 +79,6 @@ class LoRALayerWeights:
rank: int,
dtype: torch.dtype,
device: torch.types.Device,
embeddings_tensor_dim: int | None = None,
) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros(
@ -99,24 +88,12 @@ class LoRALayerWeights:
[output_dim, rank], dtype=dtype, device=device, pin_memory=pin_memory
)
embeddings_tensor = (
torch.rand(
10,
embeddings_tensor_dim,
dtype=dtype,
device=device,
pin_memory=pin_memory,
)
if embeddings_tensor_dim
else None
)
return cls(
module_name,
rank=rank,
lora_alpha=1,
lora_a=lora_a,
lora_b=lora_b,
embeddings_tensor=embeddings_tensor,
)
@ -139,7 +116,6 @@ class PackedLoRALayerWeights(LoRALayerWeights):
lora_a=lora_a,
lora_b=lora_b,
scaling=scaling, # type: ignore
embeddings_tensor=None,
)
self.lora_alphas = lora_alphas
if scaling is None:

View File

@ -21,6 +21,7 @@ from vllm.lora.utils import (
from_layer,
from_layer_logits_processor,
get_supported_lora_modules,
is_base_embeddding_weights,
is_regex_target_modules,
parse_fine_tuned_lora_name,
process_packed_modules_mapping,
@ -93,14 +94,6 @@ class LoRAModel:
loras=self.loras.copy(),
)
@property
def extra_vocab_size(self) -> int:
return (
max(lora.extra_vocab_size for lora in self.loras.values())
if self.loras
else 0
)
def get_lora(self, module_name: str) -> LoRALayerWeights | None:
"""Get LoRA for a given module by name"""
return self.loras.get(module_name, None)
@ -117,7 +110,6 @@ class LoRAModel:
peft_helper: PEFTHelper,
device: str = "cuda",
dtype: torch.dtype | None = None,
embeddings: dict[str, torch.Tensor] | None = None,
target_embedding_padding: int | None = None,
embedding_modules: dict[str, str] | None = None,
embedding_padding_modules: list[str] | None = None,
@ -127,24 +119,14 @@ class LoRAModel:
pin_memory = str(device) == "cpu" and is_pin_memory_available()
loras: dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items():
if is_base_embeddding_weights(tensor_name):
continue
module_name, is_lora_a = parse_fine_tuned_lora_name(
tensor_name, weights_mapper
)
if module_name not in loras:
lora_embeddings_tensor = None
if embeddings:
assert embedding_modules is not None
embeddings_module = next(
(k for k in embedding_modules if k in module_name), None
)
if embeddings_module:
lora_embeddings_tensor = embeddings[
embedding_modules[embeddings_module]
].to(device=device, dtype=dtype)
if pin_memory:
lora_embeddings_tensor = lora_embeddings_tensor.pin_memory()
loras[module_name] = LoRALayerWeights.from_config(
module_name, peft_helper, lora_embeddings_tensor
module_name, peft_helper
)
if is_lora_a:
@ -206,15 +188,17 @@ class LoRAModel:
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
new_embeddings_tensor_path = os.path.join(
lora_dir, "new_embeddings.safetensors"
)
new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin")
# new_embeddings_tensor_path = os.path.join(
# lora_dir, "new_embeddings.safetensors"
# )
# new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin")
tensors: dict[str, torch.Tensor] = {}
unexpected_modules: list[list[str] | str] = []
def check_unexpected_modules(modules: dict):
for lora_module in modules.keys(): # noqa
if is_base_embeddding_weights(lora_module):
continue
module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
# Handle FSDP file format where experts.base_layer is the
# gate_up_proj and experts is the down_proj
@ -300,21 +284,12 @@ class LoRAModel:
else:
raise ValueError(f"{lora_dir} doesn't contain tensors")
embeddings = None
if os.path.isfile(new_embeddings_tensor_path):
embeddings = safetensors.torch.load_file(new_embeddings_tensor_path)
elif os.path.isfile(new_embeddings_bin_file_path):
embeddings = torch.load(
new_embeddings_bin_file_path, map_location=device, weights_only=True
)
return cls.from_lora_tensors(
lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id,
tensors=tensors,
peft_helper=peft_helper,
device=device,
dtype=dtype,
embeddings=embeddings,
target_embedding_padding=target_embedding_padding,
embedding_modules=embedding_modules,
embedding_padding_modules=embedding_padding_modules,
@ -474,7 +449,6 @@ class LoRAModelManager:
index,
module_lora.lora_a,
module_lora.lora_b,
module_lora.embeddings_tensor,
)
else:
module.reset_lora(index)
@ -505,7 +479,6 @@ class LoRAModelManager:
self.lora_index_to_id,
self.lora_slots + 1,
self.vocab_size,
self.lora_config.lora_extra_vocab_size,
)
def remove_all_adapters(self):
@ -616,7 +589,6 @@ class LoRAModelManager:
if parts[-1] in embedding_modules:
input_dim = (
module.base_layer.org_vocab_size
+ self.lora_config.lora_extra_vocab_size
if hasattr(module.base_layer, "org_vocab_size")
else module.base_layer.weight.shape[1]
)
@ -625,11 +597,6 @@ class LoRAModelManager:
if hasattr(module.base_layer, "embedding_dim")
else module.base_layer.weight.shape[0]
)
embeddings_tensor_dim = (
module.base_layer.embedding_dim
if hasattr(module.base_layer, "embedding_dim")
else module.base_layer.weight.shape[1]
)
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
input_dim,
@ -637,7 +604,6 @@ class LoRAModelManager:
rank,
module.lora_a_stacked[0].dtype,
"cpu",
embeddings_tensor_dim=embeddings_tensor_dim,
)
else:
lora = LoRALayerWeights.create_dummy_lora_weights(

View File

@ -31,7 +31,6 @@ class PunicaWrapperABC(ABC):
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
**kwargs,
) -> None:
"""
@ -172,8 +171,11 @@ class PunicaWrapperBase(PunicaWrapperABC):
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
):
# NOTE: LoRA extra vocab support has been removed, so extra_vocab_size is
# always set to 0 here; the extra_vocab_size argument itself will be removed.
extra_vocab_size = 0
(
base_indices,
sampler_indices,
@ -285,12 +287,9 @@ class PunicaWrapperBase(PunicaWrapperABC):
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
**kwargs,
):
self._update_base_metadata(
mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size
)
self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size)
if mapping.is_prefill:
# Update metadata required for prefill-related operators.

View File

@ -65,13 +65,10 @@ class PunicaWrapperGPU(PunicaWrapperBase):
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
**kwargs,
):
self.is_prefill = mapping.is_prefill
self._update_base_metadata(
mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size
)
self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size)
# Prepare cuda kernel metadata tensors
self.token_mapping_meta.prepare_tensors(self.token_lora_indices)

View File

@ -292,7 +292,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
):
# Make sure we don't accidentally collect outside operations
torch_xla.sync()
@ -313,7 +312,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
lora_index_to_id,
max_loras,
vocab_size,
extra_vocab_size,
0, # extra_vocab_size
"cpu",
)
self._token_lora_indices = self._pad_to_shape(

View File

@ -43,13 +43,10 @@ class PunicaWrapperXPU(PunicaWrapperBase):
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
**kwargs,
):
self.is_prefill = mapping.is_prefill
self._update_base_metadata(
mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size
)
self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size)
def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))

View File

@ -166,6 +166,16 @@ def parse_fine_tuned_lora_name(
raise ValueError(f"{name} is unsupported LoRA weight")
def is_base_embeddding_weights(name: str) -> bool:
# hardcoded suffixes for input & output embedding weights
input_embedding_subfix = ".embed_tokens.base_layer.weight"
output_embedding_subfix = ".lm_head.base_layer.weight"
return name.endswith(input_embedding_subfix) or name.endswith(
output_embedding_subfix
)
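
A quick illustration of what the new helper matches (not part of the diff): full embed_tokens / lm_head base weights that some adapter checkpoints carry are skipped instead of being parsed as LoRA weights.

from vllm.lora.utils import is_base_embeddding_weights

print(is_base_embeddding_weights("model.embed_tokens.base_layer.weight"))  # True
print(is_base_embeddding_weights("model.lm_head.base_layer.weight"))  # True
print(is_base_embeddding_weights("model.layers.0.self_attn.q_proj.lora_A.weight"))  # False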
def is_regex_target_modules(
load_modules: str | list[str], expected_lora_modules: list[str]
) -> bool:

View File

@ -121,8 +121,7 @@ class WorkerLoRAManager:
lora_model_id=lora_request.lora_int_id,
device="cpu",
dtype=self.lora_config.lora_dtype,
target_embedding_padding=self.vocab_size
+ self.lora_config.lora_extra_vocab_size,
target_embedding_padding=self.vocab_size,
embedding_modules=self.embedding_modules,
embedding_padding_modules=self.embedding_padding_modules,
tensorizer_config_dict=lora_request.tensorizer_config_dict,
@ -143,12 +142,6 @@ class WorkerLoRAManager:
# For BadRequestError
raise e
if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
raise ValueError(
f"LoRA added vocab size {lora.extra_vocab_size} "
f"is greater than lora_extra_vocab_size "
f"{self.lora_config.lora_extra_vocab_size}."
)
return lora
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:

View File

@ -46,7 +46,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
@ -261,29 +260,16 @@ class GraniteModel(nn.Module):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
if get_pp_group().is_first_rank or (
config.tie_word_embeddings and get_pp_group().is_last_rank
):
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
else:
@ -420,28 +406,18 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = GraniteModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
@ -453,7 +429,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
logit_scale /= config.logits_scaling
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size, scale=logit_scale
config.vocab_size, scale=logit_scale
)
else:
self.lm_head = PPMissingLayer()

View File

@ -47,7 +47,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
@ -368,24 +367,18 @@ class LlamaModel(nn.Module):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.vocab_size = config.vocab_size
if get_pp_group().is_first_rank or (
config.tie_word_embeddings and get_pp_group().is_last_rank
):
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
)
else:
@ -562,9 +555,7 @@ class LlamaForCausalLM(
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.model = self._init_model(
vllm_config=vllm_config,
@ -573,20 +564,9 @@ class LlamaForCausalLM(
)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size
),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
@ -595,7 +575,7 @@ class LlamaForCausalLM(
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size, logit_scale
config.vocab_size, scale=logit_scale
)
else:
self.lm_head = PPMissingLayer()
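
With the extra-vocab path gone, the model files touched by this diff (Granite, Llama, Mixtral, TeleFLM) all collapse to the same head construction. A condensed sketch of the shared pattern, assuming the usual vLLM layer imports (the helper name is mine):

from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.utils import maybe_prefix


def build_output_head(config, quant_config, prefix: str, logit_scale: float = 1.0):
    # No lora_extra_vocab_size padding and no lora_vocab_padding_size override any more.
    lm_head = ParallelLMHead(
        config.vocab_size,
        config.hidden_size,
        quant_config=quant_config,
        prefix=maybe_prefix(prefix, "lm_head"),
    )
    logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
    return lm_head, logits_processor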

View File

@ -51,7 +51,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
@ -301,23 +300,18 @@ class MixtralModel(nn.Module):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
parallel_config = vllm_config.parallel_config
self.config = config
self.quant_config = quant_config
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.vocab_size = config.vocab_size
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.enable_eplb = parallel_config.enable_eplb
@ -508,34 +502,24 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = MixtralModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)

View File

@ -74,5 +74,5 @@ class TeleFLMForCausalLM(LlamaForCausalLM):
self.output_mult = self.config.output_mult / self.mup_scale_factor
logit_scale = self.output_mult
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, self.config.vocab_size, logit_scale
self.config.vocab_size, scale=logit_scale
)

View File

@ -219,9 +219,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.hidden_size = model_config.get_hidden_size()
self.vocab_size = model_config.get_vocab_size()
if self.lora_config is not None:
self.vocab_size += self.lora_config.lora_extra_vocab_size
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope