mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-29 06:57:14 +08:00
[LoRA][2/2]Remove LoRA extra vocab (#28545)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
df44df0143
commit
9875be6431
@ -250,6 +250,16 @@ def olmoe_lora_files():
|
||||
return snapshot_download(repo_id="jeeejeee/olmoe-instruct-text2sql-spider")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def qwen3_lora_files():
|
||||
return snapshot_download(repo_id="charent/self_cognition_Alice")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llama32_lora_files():
|
||||
return snapshot_download(repo_id="jeeejeee/llama32-3b-text2sql-spider")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_default_device():
|
||||
"""
|
||||
|
||||
@ -136,7 +136,6 @@ def populate_loras(
|
||||
id_to_index: list[int | None],
|
||||
layer: BaseLayerWithLoRA,
|
||||
layer_weights: torch.Tensor,
|
||||
generate_embeddings_tensor: int = 0,
|
||||
repeats: int = 1,
|
||||
) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
|
||||
"""This method populates the lora layers with lora weights.
|
||||
@ -148,8 +147,6 @@ def populate_loras(
|
||||
layer: the LoRAlayer to populate.
|
||||
layer_weights: the PyTorch tensor containing the layer's
|
||||
weights.
|
||||
generate_embeddings_tensor: whether to generate an
|
||||
embeddings tensor for each LoRA.
|
||||
repeats: must only be set for column parallel packed
|
||||
layers. Indicates the number of loras to compose
|
||||
together to create a single lora layer.
|
||||
@ -171,7 +168,6 @@ def populate_loras(
|
||||
sublora = DummyLoRAManager(layer_weights.device).init_random_lora(
|
||||
module_name=f"fake_{i}",
|
||||
weight=layer_weights,
|
||||
generate_embeddings_tensor=generate_embeddings_tensor,
|
||||
)
|
||||
sublora.lora_b = sublora.lora_b[
|
||||
(sublora_len * i) : (sublora_len * (i + 1)), :
|
||||
@ -185,7 +181,6 @@ def populate_loras(
|
||||
slot_idx,
|
||||
lora_a=lora.lora_a,
|
||||
lora_b=lora.lora_b,
|
||||
embeddings_tensor=lora.embeddings_tensor,
|
||||
)
|
||||
|
||||
lora_dict[lora_id] = lora
|
||||
@ -306,7 +301,6 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
|
||||
id_to_index,
|
||||
max_loras,
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
|
||||
lora_result = lora_embedding(torch.cat(inputs))
|
||||
@ -344,7 +338,6 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
|
||||
id_to_index,
|
||||
max_loras,
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
|
||||
lora_result = lora_embedding(torch.cat(inputs))
|
||||
@ -354,149 +347,6 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
|
||||
torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
# @pytest.mark.skip(
|
||||
# reason="Fails when loras are in any slot other than the first.")
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4])
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
|
||||
@pytest.mark.parametrize("stage", STAGES)
|
||||
def test_embeddings_with_new_embeddings(
|
||||
dist_init, num_loras, device, vocab_size, stage
|
||||
) -> None:
|
||||
if current_platform.is_cuda_alike():
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
lora_config = LoRAConfig(
|
||||
max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
|
||||
)
|
||||
|
||||
def create_random_embedding_layer():
|
||||
embedding = VocabParallelEmbedding(vocab_size, 256)
|
||||
embedding_data = torch.rand_like(embedding.weight.data)
|
||||
embedding.weight.data = embedding_data
|
||||
embedding.weight.data[vocab_size:, :] = 0
|
||||
expanded_embedding = VocabParallelEmbedding(
|
||||
vocab_size + lora_config.lora_extra_vocab_size * max_loras,
|
||||
256,
|
||||
org_num_embeddings=vocab_size,
|
||||
)
|
||||
expanded_embedding.weight.data[:vocab_size, :] = embedding_data
|
||||
# We need to deepcopy the embedding as it will be modified
|
||||
# in place
|
||||
lora_embedding = VocabParallelEmbeddingWithLoRA(deepcopy(expanded_embedding))
|
||||
lora_embedding.create_lora_weights(max_loras, lora_config)
|
||||
|
||||
return expanded_embedding, lora_embedding
|
||||
|
||||
for i in range(NUM_RANDOM_SEEDS):
|
||||
set_random_seed(i)
|
||||
|
||||
id_to_index = get_random_id_to_index(num_loras, max_loras)
|
||||
expanded_embedding, lora_embedding = create_random_embedding_layer()
|
||||
lora_dict, _ = populate_loras(
|
||||
id_to_index,
|
||||
layer=lora_embedding,
|
||||
layer_weights=torch.zeros(
|
||||
(256, vocab_size + lora_config.lora_extra_vocab_size)
|
||||
),
|
||||
generate_embeddings_tensor=256,
|
||||
)
|
||||
|
||||
lora_embedding.set_mapping(punica_wrapper)
|
||||
# All embeddings tensors have the same shape.
|
||||
embeddings_tensors = [
|
||||
lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys())
|
||||
]
|
||||
embeddings_tensor_len = embeddings_tensors[0].shape[0]
|
||||
|
||||
# Add empty embeddings_tensors for unoccupied lora slots.
|
||||
for _ in range(max_loras - len(embeddings_tensors)):
|
||||
embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape))
|
||||
|
||||
inputs, index_mapping, prompt_mapping = create_random_inputs(
|
||||
active_lora_ids=list(lora_dict.keys()),
|
||||
num_inputs=num_loras * 3,
|
||||
input_size=(200,),
|
||||
input_range=(1, vocab_size),
|
||||
device=device,
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
|
||||
punica_wrapper.update_metadata(
|
||||
lora_mapping,
|
||||
id_to_index,
|
||||
max_loras,
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
original_inputs = deepcopy(inputs)
|
||||
|
||||
# Force some of the inputs to be in the extended embeddings range
|
||||
# to guarantee that their behavior is tested.
|
||||
for input_, original_input_, lora_id in zip(
|
||||
inputs, original_inputs, prompt_mapping
|
||||
):
|
||||
embedding_id = lora_id - 1
|
||||
input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
|
||||
original_input_[-1] = vocab_size
|
||||
input_[-2] = vocab_size + ((embedding_id + 1) * embeddings_tensor_len - 1)
|
||||
original_input_[-2] = vocab_size + embeddings_tensor_len - 1
|
||||
|
||||
expanded_embedding.weight[
|
||||
vocab_size : vocab_size + (embeddings_tensor_len * max_loras)
|
||||
] = torch.cat(embeddings_tensors)
|
||||
|
||||
lora_result = lora_embedding(torch.cat(original_inputs))
|
||||
|
||||
expected_results: list[torch.Tensor] = []
|
||||
for input_, original_input_, lora_id in zip(
|
||||
inputs, original_inputs, prompt_mapping
|
||||
):
|
||||
lora = lora_dict[lora_id]
|
||||
result = expanded_embedding(input_)
|
||||
after_a = F.embedding(
|
||||
original_input_,
|
||||
lora.lora_a.T,
|
||||
)
|
||||
result += after_a @ lora.lora_b.T
|
||||
expected_results.append(result)
|
||||
expected_result = torch.cat(expected_results)
|
||||
|
||||
rtol, atol = TOLERANCES[lora_result.dtype]
|
||||
torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
|
||||
|
||||
# Check that resetting the lora weights succeeds
|
||||
|
||||
for slot_idx in range(max_loras):
|
||||
lora_embedding.reset_lora(slot_idx)
|
||||
|
||||
inputs, index_mapping, prompt_mapping = create_random_inputs(
|
||||
active_lora_ids=[0],
|
||||
num_inputs=num_loras * 3,
|
||||
input_size=(200,),
|
||||
input_range=(1, vocab_size),
|
||||
device=device,
|
||||
)
|
||||
original_inputs = deepcopy(inputs)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
|
||||
punica_wrapper.update_metadata(
|
||||
lora_mapping,
|
||||
id_to_index,
|
||||
max_loras,
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
lora_result = lora_embedding(torch.cat(original_inputs))
|
||||
expected_result = expanded_embedding(torch.cat(inputs))
|
||||
|
||||
rtol, atol = TOLERANCES[lora_result.dtype]
|
||||
torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4])
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@ -518,16 +368,13 @@ def test_lm_head_logits_processor(
|
||||
|
||||
def _pretest():
|
||||
linear = ParallelLMHead(
|
||||
vocab_size + lora_config.lora_extra_vocab_size,
|
||||
1024,
|
||||
vocab_size,
|
||||
num_embeddings=vocab_size,
|
||||
embedding_dim=1024,
|
||||
params_dtype=torch.float16,
|
||||
)
|
||||
linear.weight.data = torch.rand_like(linear.weight.data)
|
||||
linear.weight.data[:, vocab_size:] = 0
|
||||
logits_processor = LogitsProcessor(
|
||||
vocab_size + lora_config.lora_extra_vocab_size, vocab_size
|
||||
)
|
||||
logits_processor = LogitsProcessor(vocab_size)
|
||||
lora_logits_processor = LogitsProcessorWithLoRA(
|
||||
logits_processor, 1024, linear.weight.dtype, linear.weight.device, None
|
||||
)
|
||||
@ -541,15 +388,12 @@ def test_lm_head_logits_processor(
|
||||
id_to_index = get_random_id_to_index(num_loras, max_loras)
|
||||
linear, logits_processor, lora_logits_processor = _pretest()
|
||||
lora_logits_processor.set_mapping(punica_wrapper)
|
||||
# NOTE: all the generated loras share the same embeddings tensor.
|
||||
|
||||
lora_dict, _ = populate_loras(
|
||||
id_to_index,
|
||||
layer=lora_logits_processor,
|
||||
layer_weights=linear.weight,
|
||||
generate_embeddings_tensor=1024,
|
||||
)
|
||||
embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor
|
||||
embeddings_tensor_len = embeddings_tensor.shape[0]
|
||||
|
||||
inputs, index_mapping, prompt_mapping = create_random_inputs(
|
||||
active_lora_ids=list(lora_dict.keys()),
|
||||
@ -565,7 +409,6 @@ def test_lm_head_logits_processor(
|
||||
id_to_index,
|
||||
max_loras,
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
input_ = torch.rand(20, 1024)
|
||||
|
||||
@ -575,23 +418,16 @@ def test_lm_head_logits_processor(
|
||||
|
||||
original_lm_head = deepcopy(linear)
|
||||
|
||||
linear.weight[
|
||||
logits_processor.org_vocab_size : logits_processor.org_vocab_size
|
||||
+ embeddings_tensor_len
|
||||
] = embeddings_tensor
|
||||
|
||||
logits_processor.org_vocab_size = vocab_size + lora_config.lora_extra_vocab_size
|
||||
expected_results: list[torch.Tensor] = []
|
||||
for input_, lora_id in zip(inputs, prompt_mapping):
|
||||
lora = lora_dict[lora_id]
|
||||
result = logits_processor._get_logits(
|
||||
hidden_states=input_, lm_head=linear, embedding_bias=None
|
||||
)
|
||||
result[:, vocab_size + embeddings_tensor_len :] = float("-inf")
|
||||
|
||||
result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
|
||||
expected_results.append(result)
|
||||
expected_result = torch.cat(expected_results)
|
||||
logits_processor.org_vocab_size = vocab_size
|
||||
|
||||
# Check that resetting the lora weights succeeds
|
||||
|
||||
@ -612,7 +448,6 @@ def test_lm_head_logits_processor(
|
||||
id_to_index,
|
||||
max_loras,
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
|
||||
lora_result = lora_logits_processor._get_logits(
|
||||
@ -694,7 +529,6 @@ def test_linear_replicated(
|
||||
id_to_index,
|
||||
max_loras,
|
||||
512,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
|
||||
lora_result = lora_linear(torch.cat(inputs))[0]
|
||||
@ -726,7 +560,10 @@ def test_linear_replicated(
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
|
||||
|
||||
punica_wrapper.update_metadata(
|
||||
lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size
|
||||
lora_mapping,
|
||||
id_to_index,
|
||||
max_loras,
|
||||
512,
|
||||
)
|
||||
|
||||
lora_result = lora_linear(torch.cat(inputs))[0]
|
||||
@ -817,7 +654,6 @@ def test_linear_parallel(
|
||||
id_to_index,
|
||||
max_loras,
|
||||
512,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
|
||||
lora_result = lora_linear(torch.cat(inputs))[0]
|
||||
@ -849,7 +685,10 @@ def test_linear_parallel(
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
|
||||
|
||||
punica_wrapper.update_metadata(
|
||||
lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size
|
||||
lora_mapping,
|
||||
id_to_index,
|
||||
max_loras,
|
||||
512,
|
||||
)
|
||||
|
||||
lora_result = lora_linear(torch.cat(inputs))[0]
|
||||
@ -963,7 +802,6 @@ def test_column_parallel_packed(
|
||||
id_to_index,
|
||||
max_loras,
|
||||
512,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
|
||||
lora_result = lora_linear(torch.cat(inputs))[0]
|
||||
@ -1000,7 +838,6 @@ def test_column_parallel_packed(
|
||||
id_to_index,
|
||||
max_loras,
|
||||
512,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
|
||||
lora_result = lora_linear(torch.cat(inputs))[0]
|
||||
|
||||
@ -13,17 +13,27 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
|
||||
from ..utils import VLLM_PATH, create_new_process_for_each_test, multi_gpu_test
|
||||
|
||||
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
|
||||
PROMPT_TEMPLATE = """<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
|
||||
"
|
||||
##Instruction:
|
||||
candidate_poll contains tables such as candidate, people. Table candidate has columns such as Candidate_ID, People_ID, Poll_Source, Date, Support_rate, Consider_rate, Oppose_rate, Unsure_rate. Candidate_ID is the primary key.
|
||||
Table people has columns such as People_ID, Sex, Name, Date_of_Birth, Height, Weight. People_ID is the primary key.
|
||||
The People_ID of candidate is the foreign key of People_ID of people.
|
||||
###Input:
|
||||
{context}
|
||||
###Response:<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
""" # noqa: E501
|
||||
|
||||
EXPECTED_LORA_OUTPUT = [
|
||||
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
|
||||
" SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ",
|
||||
" SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501
|
||||
" SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501
|
||||
" SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ",
|
||||
" SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' ", # noqa: E501
|
||||
"SELECT count(*) FROM candidate",
|
||||
"SELECT count(*) FROM candidate",
|
||||
"SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
|
||||
"SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
|
||||
]
|
||||
|
||||
MODEL_PATH = "meta-llama/Llama-3.2-3B-Instruct"
|
||||
|
||||
|
||||
def do_sample(
|
||||
llm: vllm.LLM,
|
||||
@ -32,18 +42,19 @@ def do_sample(
|
||||
tensorizer_config_dict: dict | None = None,
|
||||
) -> list[str]:
|
||||
prompts = [
|
||||
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
|
||||
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
|
||||
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501
|
||||
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501
|
||||
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501
|
||||
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]", # noqa: E501
|
||||
PROMPT_TEMPLATE.format(context="How many candidates are there?"),
|
||||
PROMPT_TEMPLATE.format(context="Count the number of candidates."),
|
||||
PROMPT_TEMPLATE.format(
|
||||
context="Which poll resource provided the most number of candidate information?" # noqa: E501
|
||||
),
|
||||
PROMPT_TEMPLATE.format(
|
||||
context="Return the poll resource associated with the most candidates."
|
||||
),
|
||||
]
|
||||
|
||||
sampling_params = vllm.SamplingParams(
|
||||
temperature=0, max_tokens=256, skip_special_tokens=False, stop=["[/assistant]"]
|
||||
temperature=0, max_tokens=64, stop=["<|im_end|>"]
|
||||
)
|
||||
|
||||
if tensorizer_config_dict is not None:
|
||||
outputs = llm.generate(
|
||||
prompts,
|
||||
@ -75,13 +86,15 @@ def do_sample(
|
||||
return generated_texts
|
||||
|
||||
|
||||
def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None = None):
|
||||
def generate_and_test(
|
||||
llm, llama32_lora_files, tensorizer_config_dict: dict | None = None
|
||||
):
|
||||
print("lora adapter created")
|
||||
print("lora 1")
|
||||
assert (
|
||||
do_sample(
|
||||
llm,
|
||||
sql_lora_files,
|
||||
llama32_lora_files,
|
||||
tensorizer_config_dict=tensorizer_config_dict,
|
||||
lora_id=1,
|
||||
)
|
||||
@ -92,7 +105,7 @@ def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None =
|
||||
assert (
|
||||
do_sample(
|
||||
llm,
|
||||
sql_lora_files,
|
||||
llama32_lora_files,
|
||||
tensorizer_config_dict=tensorizer_config_dict,
|
||||
lora_id=2,
|
||||
)
|
||||
@ -104,51 +117,52 @@ def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None =
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@pytest.mark.parametrize("cudagraph_specialize_lora", [True, False])
|
||||
def test_llama_lora(sql_lora_files, cudagraph_specialize_lora: bool):
|
||||
def test_llama_lora(llama32_lora_files, cudagraph_specialize_lora: bool):
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
tokenizer=sql_lora_files,
|
||||
enable_lora=True,
|
||||
# also test odd max_num_seqs
|
||||
max_num_seqs=13,
|
||||
max_num_seqs=7,
|
||||
max_model_len=1024,
|
||||
max_loras=4,
|
||||
compilation_config=vllm.config.CompilationConfig(
|
||||
cudagraph_specialize_lora=cudagraph_specialize_lora,
|
||||
),
|
||||
)
|
||||
generate_and_test(llm, sql_lora_files)
|
||||
generate_and_test(llm, llama32_lora_files)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
|
||||
def test_llama_lora_tp4(sql_lora_files):
|
||||
def test_llama_lora_tp4(llama32_lora_files):
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
tokenizer=sql_lora_files,
|
||||
enable_lora=True,
|
||||
max_num_seqs=16,
|
||||
max_num_seqs=7,
|
||||
max_model_len=1024,
|
||||
max_loras=4,
|
||||
tensor_parallel_size=4,
|
||||
)
|
||||
generate_and_test(llm, sql_lora_files)
|
||||
generate_and_test(llm, llama32_lora_files)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
|
||||
def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
|
||||
def test_llama_lora_tp4_fully_sharded_loras(llama32_lora_files):
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
tokenizer=sql_lora_files,
|
||||
enable_lora=True,
|
||||
max_num_seqs=16,
|
||||
max_num_seqs=8,
|
||||
max_loras=4,
|
||||
max_model_len=1024,
|
||||
tensor_parallel_size=4,
|
||||
fully_sharded_loras=True,
|
||||
)
|
||||
generate_and_test(llm, sql_lora_files)
|
||||
generate_and_test(llm, llama32_lora_files)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_tp2_serialize_and_deserialize_lora(
|
||||
tmp_path, sql_lora_files, sql_lora_huggingface_id
|
||||
tmp_path,
|
||||
llama32_lora_files,
|
||||
):
|
||||
# Run the tensorizing of the LoRA adapter and the model in a subprocess
|
||||
# to guarantee cleanup
|
||||
@ -157,7 +171,7 @@ def test_tp2_serialize_and_deserialize_lora(
|
||||
model_name = "model-rank-%03d.tensors"
|
||||
|
||||
model_ref = MODEL_PATH
|
||||
lora_path = sql_lora_huggingface_id
|
||||
lora_path = llama32_lora_files
|
||||
suffix = "test"
|
||||
try:
|
||||
result = subprocess.run(
|
||||
@ -195,12 +209,12 @@ def test_tp2_serialize_and_deserialize_lora(
|
||||
|
||||
loaded_llm = LLM(
|
||||
model=model_ref,
|
||||
tokenizer=sql_lora_files,
|
||||
load_format="tensorizer",
|
||||
enable_lora=True,
|
||||
enforce_eager=True,
|
||||
model_loader_extra_config=tensorizer_config,
|
||||
max_num_seqs=13,
|
||||
max_num_seqs=7,
|
||||
max_model_len=1024,
|
||||
tensor_parallel_size=2,
|
||||
max_loras=2,
|
||||
)
|
||||
@ -211,7 +225,7 @@ def test_tp2_serialize_and_deserialize_lora(
|
||||
print("lora 1")
|
||||
assert (
|
||||
do_sample(
|
||||
loaded_llm, sql_lora_files, tensorizer_config_dict=tc_as_dict, lora_id=1
|
||||
loaded_llm, llama32_lora_files, tensorizer_config_dict=tc_as_dict, lora_id=1
|
||||
)
|
||||
== EXPECTED_LORA_OUTPUT
|
||||
)
|
||||
|
||||
@ -13,8 +13,8 @@ from vllm.entrypoints.openai.api_server import (
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.v1.engine.llm_engine import LLMEngine
|
||||
|
||||
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
|
||||
LORA_MODULE_PATH = "yard1/llama-2-7b-sql-lora-test"
|
||||
MODEL_PATH = "Qwen/Qwen3-0.6B"
|
||||
LORA_MODULE_PATH = "charent/self_cognition_Alice"
|
||||
LORA_RANK = 8
|
||||
|
||||
|
||||
|
||||
@ -48,9 +48,6 @@ DEFAULT_DTYPE = torch.get_default_dtype()
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_from_lora_tensors(sql_lora_files, device):
|
||||
tensors = load_file(os.path.join(sql_lora_files, "adapter_model.safetensors"))
|
||||
new_embeddings = load_file(
|
||||
os.path.join(sql_lora_files, "new_embeddings.safetensors")
|
||||
)
|
||||
|
||||
peft_helper = PEFTHelper.from_local_dir(
|
||||
sql_lora_files, max_position_embeddings=4096
|
||||
@ -60,7 +57,6 @@ def test_from_lora_tensors(sql_lora_files, device):
|
||||
tensors,
|
||||
peft_helper=peft_helper,
|
||||
device=device,
|
||||
embeddings=new_embeddings,
|
||||
embedding_modules=EMBEDDING_MODULES,
|
||||
embedding_padding_modules=EMBEDDING_PADDING_MODULES,
|
||||
)
|
||||
@ -76,18 +72,6 @@ def test_from_lora_tensors(sql_lora_files, device):
|
||||
f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
|
||||
)
|
||||
assert lora.lora_a.shape[0] == 8
|
||||
embeddings_module = next(
|
||||
(k for k in EMBEDDING_MODULES if k in module_name), None
|
||||
)
|
||||
if embeddings_module:
|
||||
assert torch.equal(
|
||||
lora.embeddings_tensor,
|
||||
new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
|
||||
device=lora.embeddings_tensor.device
|
||||
),
|
||||
)
|
||||
else:
|
||||
assert lora.embeddings_tensor is None
|
||||
|
||||
|
||||
def create_lora(
|
||||
@ -552,9 +536,7 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path
|
||||
worker_adapter_manager = WorkerLoRAManager(
|
||||
vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES
|
||||
)
|
||||
worker_adapter_manager.vocab_size = (
|
||||
dummy_model_gate_up.unpadded_vocab_size - lora_config.lora_extra_vocab_size
|
||||
)
|
||||
worker_adapter_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
|
||||
worker_adapter_manager.create_lora_manager(dummy_model_gate_up)
|
||||
|
||||
dummy_lora_files = f"{tmp_path}/lora_adapter"
|
||||
|
||||
@ -20,11 +20,12 @@ from vllm.lora.models import LoRAMapping
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.v1.worker.gpu_worker import Worker
|
||||
|
||||
MODEL_PATH = "Qwen/Qwen3-0.6B"
|
||||
NUM_LORAS = 16
|
||||
|
||||
|
||||
@patch.dict(os.environ, {"RANK": "0"})
|
||||
def test_worker_apply_lora(sql_lora_files):
|
||||
def test_worker_apply_lora(qwen3_lora_files):
|
||||
def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]):
|
||||
lora_mapping = LoRAMapping([], [])
|
||||
|
||||
@ -34,9 +35,10 @@ def test_worker_apply_lora(sql_lora_files):
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
model_config=ModelConfig(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
MODEL_PATH,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
max_model_len=127,
|
||||
enforce_eager=True,
|
||||
),
|
||||
load_config=LoadConfig(
|
||||
@ -73,7 +75,7 @@ def test_worker_apply_lora(sql_lora_files):
|
||||
assert worker.list_loras() == set()
|
||||
|
||||
lora_requests = [
|
||||
LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(NUM_LORAS)
|
||||
LoRARequest(str(i + 1), i + 1, qwen3_lora_files) for i in range(NUM_LORAS)
|
||||
]
|
||||
|
||||
set_active_loras(worker, lora_requests)
|
||||
|
||||
@ -28,7 +28,6 @@ class DummyLoRAManager:
|
||||
module_name: str,
|
||||
weight: torch.Tensor,
|
||||
rank: int = 8,
|
||||
generate_embeddings_tensor: int = 0,
|
||||
):
|
||||
lora = LoRALayerWeights(
|
||||
module_name,
|
||||
@ -41,13 +40,6 @@ class DummyLoRAManager:
|
||||
[weight.shape[0], rank], dtype=weight.dtype, device=self._device
|
||||
),
|
||||
)
|
||||
if generate_embeddings_tensor:
|
||||
lora.embeddings_tensor = torch.rand(
|
||||
5,
|
||||
generate_embeddings_tensor,
|
||||
dtype=weight.dtype,
|
||||
device=self._device,
|
||||
)
|
||||
self.set_module_lora(module_name, lora)
|
||||
|
||||
return lora
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import torch
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
@ -11,7 +11,6 @@ from typing_extensions import Self
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
@ -46,19 +45,6 @@ class LoRAConfig:
|
||||
`max_loras`."""
|
||||
lora_dtype: torch.dtype | LoRADType = "auto"
|
||||
"""Data type for LoRA. If auto, will default to base model dtype."""
|
||||
lora_extra_vocab_size: LoRAExtraVocabSize = Field(
|
||||
default=256,
|
||||
deprecated=(
|
||||
"`lora_extra_vocab_size` is deprecated and will be removed "
|
||||
"in v0.12.0. Additional vocabulary support for "
|
||||
"LoRA adapters is being phased out."
|
||||
),
|
||||
)
|
||||
"""(Deprecated) Maximum size of extra vocabulary that can be present in a
|
||||
LoRA adapter. Will be removed in v0.12.0."""
|
||||
lora_vocab_padding_size: ClassVar[int] = (
|
||||
current_platform.get_lora_vocab_padding_size()
|
||||
)
|
||||
default_mm_loras: dict[str, str] | None = None
|
||||
"""Dictionary mapping specific modalities to LoRA model paths; this field
|
||||
is only applicable to multimodal models and should be leveraged when a
|
||||
@ -87,8 +73,6 @@ class LoRAConfig:
|
||||
factors.append(self.max_loras)
|
||||
factors.append(self.fully_sharded_loras)
|
||||
factors.append(self.lora_dtype)
|
||||
factors.append(self.lora_extra_vocab_size)
|
||||
factors.append(self.lora_vocab_padding_size)
|
||||
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@ -484,7 +484,6 @@ class EngineArgs:
|
||||
fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
|
||||
max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
|
||||
lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
|
||||
lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
|
||||
|
||||
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
|
||||
num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override
|
||||
@ -1011,9 +1010,6 @@ class EngineArgs:
|
||||
)
|
||||
lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
|
||||
lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
|
||||
lora_group.add_argument(
|
||||
"--lora-extra-vocab-size", **lora_kwargs["lora_extra_vocab_size"]
|
||||
)
|
||||
lora_group.add_argument(
|
||||
"--lora-dtype",
|
||||
**lora_kwargs["lora_dtype"],
|
||||
@ -1680,7 +1676,6 @@ class EngineArgs:
|
||||
max_loras=self.max_loras,
|
||||
default_mm_loras=self.default_mm_loras,
|
||||
fully_sharded_loras=self.fully_sharded_loras,
|
||||
lora_extra_vocab_size=self.lora_extra_vocab_size,
|
||||
lora_dtype=self.lora_dtype,
|
||||
max_cpu_loras=self.max_cpu_loras
|
||||
if self.max_cpu_loras and self.max_cpu_loras > 0
|
||||
|
||||
@ -44,7 +44,6 @@ class BaseLayerWithLoRA(nn.Module):
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: torch.Tensor | None,
|
||||
):
|
||||
"""Overwrites lora tensors at index."""
|
||||
...
|
||||
|
||||
@ -96,7 +96,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: torch.Tensor | None,
|
||||
):
|
||||
# Except for QKVParallelLinearWithLoRA and
|
||||
# MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
|
||||
|
||||
@ -248,7 +248,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: torch.Tensor | None,
|
||||
):
|
||||
self.reset_lora(index)
|
||||
|
||||
|
||||
@ -406,8 +406,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: torch.Tensor | None,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
"""Overwrites lora tensors at index."""
|
||||
self.reset_lora(index)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -108,22 +107,13 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
||||
(
|
||||
max_loras,
|
||||
1,
|
||||
# Pad for kernel compatibility
|
||||
math.ceil(
|
||||
self.base_layer.vocab_size / lora_config.lora_vocab_padding_size
|
||||
)
|
||||
* lora_config.lora_vocab_padding_size,
|
||||
self.base_layer.vocab_size,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.embeddings_tensors = torch.full(
|
||||
(max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
|
||||
fill_value=float("-inf"),
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
if self.sharded_to_full_mapping is not None:
|
||||
self.sharded_to_full_mapping_gpu = torch.tensor(
|
||||
self.sharded_to_full_mapping, device=self.device, dtype=torch.long
|
||||
@ -134,14 +124,12 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[index] = 0
|
||||
self.lora_b_stacked[index] = 0
|
||||
self.embeddings_tensors[index] = float("-inf")
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: torch.Tensor | None,
|
||||
):
|
||||
self.reset_lora(index)
|
||||
self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
|
||||
@ -150,12 +138,6 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
||||
self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
|
||||
lora_b, non_blocking=True
|
||||
)
|
||||
if embeddings_tensor is not None:
|
||||
self.embeddings_tensors[
|
||||
index,
|
||||
: embeddings_tensor.shape[0],
|
||||
: embeddings_tensor.shape[1],
|
||||
] = embeddings_tensor
|
||||
|
||||
def _get_logits(
|
||||
self,
|
||||
@ -193,39 +175,6 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
||||
# token_id: [0, 1, 2, 3, 4, 5, -1, -1]
|
||||
logits = logits[:, self.sharded_to_full_mapping_gpu]
|
||||
|
||||
lora_logits = torch.empty(
|
||||
self.embeddings_tensors.shape[0] + 1,
|
||||
self.embeddings_tensors.shape[1],
|
||||
hidden_states.shape[0],
|
||||
dtype=self.embeddings_tensors.dtype,
|
||||
device=self.embeddings_tensors.device,
|
||||
)
|
||||
torch.matmul(self.embeddings_tensors, hidden_states.T, out=lora_logits[:-1])
|
||||
|
||||
neg_inf, pos_inf = current_platform.get_infinity_values(lora_logits.dtype)
|
||||
|
||||
lora_logits[-1] = neg_inf
|
||||
lora_logits = lora_logits.mT
|
||||
indices_padded = self.punica_wrapper.sampler_indices_padded
|
||||
|
||||
if current_platform.is_tpu() or current_platform.is_xpu():
|
||||
indices_padded = indices_padded[: logits.size(0)]
|
||||
|
||||
lora_logits = (
|
||||
lora_logits.reshape(
|
||||
lora_logits.shape[0] * lora_logits.shape[1],
|
||||
lora_logits.shape[2],
|
||||
)
|
||||
.index_select(0, indices_padded)
|
||||
.nan_to_num_(nan=neg_inf, posinf=pos_inf, neginf=neg_inf)
|
||||
)
|
||||
|
||||
logits[
|
||||
:,
|
||||
self.base_layer.org_vocab_size : self.base_layer.org_vocab_size
|
||||
+ lora_logits.shape[1],
|
||||
] = lora_logits
|
||||
|
||||
lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_logits(
|
||||
logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, 1.0
|
||||
)
|
||||
|
||||
@ -46,19 +46,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
||||
self.embeddings_slice = None
|
||||
self.embeddings_weights = None
|
||||
|
||||
self.embeddings_tensors = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
self.base_layer.embedding_dim,
|
||||
),
|
||||
dtype=self.base_layer.weight.dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
)
|
||||
self.lora_a_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.org_vocab_size + lora_config.lora_extra_vocab_size,
|
||||
self.base_layer.org_vocab_size,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
@ -82,14 +73,12 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[index] = 0
|
||||
self.lora_b_stacked[index] = 0
|
||||
self.embeddings_tensors[index] = 0
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: torch.Tensor | None,
|
||||
):
|
||||
self.reset_lora(index)
|
||||
# NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
|
||||
@ -100,36 +89,18 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
||||
self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
|
||||
lora_b, non_blocking=True
|
||||
)
|
||||
if embeddings_tensor is not None:
|
||||
self.embeddings_tensors[
|
||||
index,
|
||||
: embeddings_tensor.shape[0],
|
||||
: embeddings_tensor.shape[1],
|
||||
].copy_(embeddings_tensor, non_blocking=True)
|
||||
if self.embeddings_slice is not None:
|
||||
# TODO(yard1): Optimize this copy, we don't need to copy
|
||||
# everything, just the modified part
|
||||
embeddings = self.embeddings_tensors.view(
|
||||
self.embeddings_tensors.shape[0] * self.embeddings_tensors.shape[1],
|
||||
self.embeddings_tensors.shape[2],
|
||||
)[self.embeddings_slice[0] : self.embeddings_slice[1]]
|
||||
assert self.embeddings_weights is not None
|
||||
self.embeddings_weights[: embeddings.shape[0]].copy_(embeddings)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, 1, 0)
|
||||
|
||||
# NB: Don't use torch.narrow here. torch.narrow triggers some
|
||||
# Dynamic Shape specialization in torch.compile
|
||||
num_tokens = x.shape[0]
|
||||
indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens]
|
||||
indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens]
|
||||
|
||||
full_lora_a_embeddings = F.embedding(
|
||||
x + indices_1,
|
||||
self.lora_a_stacked_2d,
|
||||
)
|
||||
full_output = self.base_layer.forward(x + (indices_0 * added_tokens_mask))
|
||||
full_output = self.base_layer.forward(x)
|
||||
|
||||
full_output_org = full_output
|
||||
if full_output.ndim == 3:
|
||||
|
||||
@ -21,7 +21,6 @@ class LoRALayerWeights:
|
||||
lora_alpha: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: torch.Tensor | None = None,
|
||||
scaling: float | None = None,
|
||||
) -> None:
|
||||
self.module_name = module_name
|
||||
@ -29,7 +28,6 @@ class LoRALayerWeights:
|
||||
self.lora_alpha = lora_alpha
|
||||
self.lora_a = lora_a
|
||||
self.lora_b = lora_b
|
||||
self.embeddings_tensor = embeddings_tensor
|
||||
|
||||
if scaling is None:
|
||||
self.scaling = self.lora_alpha / self.rank
|
||||
@ -56,18 +54,11 @@ class LoRALayerWeights:
|
||||
def is_packed(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def extra_vocab_size(self) -> int:
|
||||
return (
|
||||
self.embeddings_tensor.shape[0] if self.embeddings_tensor is not None else 0
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
module_name: str,
|
||||
peft_helper: PEFTHelper,
|
||||
embeddings_tensor: torch.Tensor | None = None,
|
||||
) -> "LoRALayerWeights":
|
||||
# lora_a and lora_b are set to None for config-based construction
|
||||
return cls(
|
||||
@ -76,7 +67,6 @@ class LoRALayerWeights:
|
||||
peft_helper.lora_alpha,
|
||||
None,
|
||||
None,
|
||||
embeddings_tensor,
|
||||
peft_helper.vllm_lora_scaling_factor,
|
||||
)
|
||||
|
||||
@ -89,7 +79,6 @@ class LoRALayerWeights:
|
||||
rank: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.types.Device,
|
||||
embeddings_tensor_dim: int | None = None,
|
||||
) -> "LoRALayerWeights":
|
||||
pin_memory = str(device) == "cpu" and is_pin_memory_available()
|
||||
lora_a = torch.zeros(
|
||||
@ -99,24 +88,12 @@ class LoRALayerWeights:
|
||||
[output_dim, rank], dtype=dtype, device=device, pin_memory=pin_memory
|
||||
)
|
||||
|
||||
embeddings_tensor = (
|
||||
torch.rand(
|
||||
10,
|
||||
embeddings_tensor_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
if embeddings_tensor_dim
|
||||
else None
|
||||
)
|
||||
return cls(
|
||||
module_name,
|
||||
rank=rank,
|
||||
lora_alpha=1,
|
||||
lora_a=lora_a,
|
||||
lora_b=lora_b,
|
||||
embeddings_tensor=embeddings_tensor,
|
||||
)
|
||||
|
||||
|
||||
@ -139,7 +116,6 @@ class PackedLoRALayerWeights(LoRALayerWeights):
|
||||
lora_a=lora_a,
|
||||
lora_b=lora_b,
|
||||
scaling=scaling, # type: ignore
|
||||
embeddings_tensor=None,
|
||||
)
|
||||
self.lora_alphas = lora_alphas
|
||||
if scaling is None:
|
||||
|
||||
@ -21,6 +21,7 @@ from vllm.lora.utils import (
|
||||
from_layer,
|
||||
from_layer_logits_processor,
|
||||
get_supported_lora_modules,
|
||||
is_base_embeddding_weights,
|
||||
is_regex_target_modules,
|
||||
parse_fine_tuned_lora_name,
|
||||
process_packed_modules_mapping,
|
||||
@ -93,14 +94,6 @@ class LoRAModel:
|
||||
loras=self.loras.copy(),
|
||||
)
|
||||
|
||||
@property
|
||||
def extra_vocab_size(self) -> int:
|
||||
return (
|
||||
max(lora.extra_vocab_size for lora in self.loras.values())
|
||||
if self.loras
|
||||
else 0
|
||||
)
|
||||
|
||||
def get_lora(self, module_name: str) -> LoRALayerWeights | None:
|
||||
"""Get LoRA for a given module by name"""
|
||||
return self.loras.get(module_name, None)
|
||||
@ -117,7 +110,6 @@ class LoRAModel:
|
||||
peft_helper: PEFTHelper,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype | None = None,
|
||||
embeddings: dict[str, torch.Tensor] | None = None,
|
||||
target_embedding_padding: int | None = None,
|
||||
embedding_modules: dict[str, str] | None = None,
|
||||
embedding_padding_modules: list[str] | None = None,
|
||||
@ -127,24 +119,14 @@ class LoRAModel:
|
||||
pin_memory = str(device) == "cpu" and is_pin_memory_available()
|
||||
loras: dict[str, LoRALayerWeights] = {}
|
||||
for tensor_name, tensor in tensors.items():
|
||||
if is_base_embeddding_weights(tensor_name):
|
||||
continue
|
||||
module_name, is_lora_a = parse_fine_tuned_lora_name(
|
||||
tensor_name, weights_mapper
|
||||
)
|
||||
if module_name not in loras:
|
||||
lora_embeddings_tensor = None
|
||||
if embeddings:
|
||||
assert embedding_modules is not None
|
||||
embeddings_module = next(
|
||||
(k for k in embedding_modules if k in module_name), None
|
||||
)
|
||||
if embeddings_module:
|
||||
lora_embeddings_tensor = embeddings[
|
||||
embedding_modules[embeddings_module]
|
||||
].to(device=device, dtype=dtype)
|
||||
if pin_memory:
|
||||
lora_embeddings_tensor = lora_embeddings_tensor.pin_memory()
|
||||
loras[module_name] = LoRALayerWeights.from_config(
|
||||
module_name, peft_helper, lora_embeddings_tensor
|
||||
module_name, peft_helper
|
||||
)
|
||||
|
||||
if is_lora_a:
|
||||
@ -206,15 +188,17 @@ class LoRAModel:
|
||||
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
|
||||
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
|
||||
lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
|
||||
new_embeddings_tensor_path = os.path.join(
|
||||
lora_dir, "new_embeddings.safetensors"
|
||||
)
|
||||
new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin")
|
||||
# new_embeddings_tensor_path = os.path.join(
|
||||
# lora_dir, "new_embeddings.safetensors"
|
||||
# )
|
||||
# new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin")
|
||||
tensors: dict[str, torch.Tensor] = {}
|
||||
unexpected_modules: list[list[str] | str] = []
|
||||
|
||||
def check_unexpected_modules(modules: dict):
|
||||
for lora_module in modules.keys(): # noqa
|
||||
if is_base_embeddding_weights(lora_module):
|
||||
continue
|
||||
module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
|
||||
# Handle FSDP file format where experts.base_layer is the
|
||||
# gate_up_proj and experts is the down_proj
|
||||
@ -300,21 +284,12 @@ class LoRAModel:
|
||||
else:
|
||||
raise ValueError(f"{lora_dir} doesn't contain tensors")
|
||||
|
||||
embeddings = None
|
||||
if os.path.isfile(new_embeddings_tensor_path):
|
||||
embeddings = safetensors.torch.load_file(new_embeddings_tensor_path)
|
||||
elif os.path.isfile(new_embeddings_bin_file_path):
|
||||
embeddings = torch.load(
|
||||
new_embeddings_bin_file_path, map_location=device, weights_only=True
|
||||
)
|
||||
|
||||
return cls.from_lora_tensors(
|
||||
lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id,
|
||||
tensors=tensors,
|
||||
peft_helper=peft_helper,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
embeddings=embeddings,
|
||||
target_embedding_padding=target_embedding_padding,
|
||||
embedding_modules=embedding_modules,
|
||||
embedding_padding_modules=embedding_padding_modules,
|
||||
@ -474,7 +449,6 @@ class LoRAModelManager:
|
||||
index,
|
||||
module_lora.lora_a,
|
||||
module_lora.lora_b,
|
||||
module_lora.embeddings_tensor,
|
||||
)
|
||||
else:
|
||||
module.reset_lora(index)
|
||||
@ -505,7 +479,6 @@ class LoRAModelManager:
|
||||
self.lora_index_to_id,
|
||||
self.lora_slots + 1,
|
||||
self.vocab_size,
|
||||
self.lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
|
||||
def remove_all_adapters(self):
|
||||
@ -616,7 +589,6 @@ class LoRAModelManager:
|
||||
if parts[-1] in embedding_modules:
|
||||
input_dim = (
|
||||
module.base_layer.org_vocab_size
|
||||
+ self.lora_config.lora_extra_vocab_size
|
||||
if hasattr(module.base_layer, "org_vocab_size")
|
||||
else module.base_layer.weight.shape[1]
|
||||
)
|
||||
@ -625,11 +597,6 @@ class LoRAModelManager:
|
||||
if hasattr(module.base_layer, "embedding_dim")
|
||||
else module.base_layer.weight.shape[0]
|
||||
)
|
||||
embeddings_tensor_dim = (
|
||||
module.base_layer.embedding_dim
|
||||
if hasattr(module.base_layer, "embedding_dim")
|
||||
else module.base_layer.weight.shape[1]
|
||||
)
|
||||
lora = LoRALayerWeights.create_dummy_lora_weights(
|
||||
module_name,
|
||||
input_dim,
|
||||
@ -637,7 +604,6 @@ class LoRAModelManager:
|
||||
rank,
|
||||
module.lora_a_stacked[0].dtype,
|
||||
"cpu",
|
||||
embeddings_tensor_dim=embeddings_tensor_dim,
|
||||
)
|
||||
else:
|
||||
lora = LoRALayerWeights.create_dummy_lora_weights(
|
||||
|
||||
@ -31,7 +31,6 @@ class PunicaWrapperABC(ABC):
|
||||
lora_index_to_id: list[int | None],
|
||||
max_loras: int,
|
||||
vocab_size: int,
|
||||
extra_vocab_size: int,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
@ -172,8 +171,11 @@ class PunicaWrapperBase(PunicaWrapperABC):
|
||||
lora_index_to_id: list[int | None],
|
||||
max_loras: int,
|
||||
vocab_size: int,
|
||||
extra_vocab_size: int,
|
||||
):
|
||||
# NOTE We have remove lora extra vocab support for now. So we set
|
||||
# extra_vocab_size alwayzs to 0, and extra_vocab_size will be removed.
|
||||
|
||||
extra_vocab_size = 0
|
||||
(
|
||||
base_indices,
|
||||
sampler_indices,
|
||||
@ -285,12 +287,9 @@ class PunicaWrapperBase(PunicaWrapperABC):
|
||||
lora_index_to_id: list[int | None],
|
||||
max_loras: int,
|
||||
vocab_size: int,
|
||||
extra_vocab_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
self._update_base_metadata(
|
||||
mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size
|
||||
)
|
||||
self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size)
|
||||
|
||||
if mapping.is_prefill:
|
||||
# Update metadata required for prefill-related operators.
|
||||
|
||||
@ -65,13 +65,10 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
lora_index_to_id: list[int | None],
|
||||
max_loras: int,
|
||||
vocab_size: int,
|
||||
extra_vocab_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
self.is_prefill = mapping.is_prefill
|
||||
self._update_base_metadata(
|
||||
mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size
|
||||
)
|
||||
self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size)
|
||||
|
||||
# Prepare cuda kernel metadata tensors
|
||||
self.token_mapping_meta.prepare_tensors(self.token_lora_indices)
|
||||
|
||||
@ -292,7 +292,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
|
||||
lora_index_to_id: list[int | None],
|
||||
max_loras: int,
|
||||
vocab_size: int,
|
||||
extra_vocab_size: int,
|
||||
):
|
||||
# Make sure we don't accidentally collect outside operations
|
||||
torch_xla.sync()
|
||||
@ -313,7 +312,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
|
||||
lora_index_to_id,
|
||||
max_loras,
|
||||
vocab_size,
|
||||
extra_vocab_size,
|
||||
0, # extra_vocab_size
|
||||
"cpu",
|
||||
)
|
||||
self._token_lora_indices = self._pad_to_shape(
|
||||
|
||||
@ -43,13 +43,10 @@ class PunicaWrapperXPU(PunicaWrapperBase):
|
||||
lora_index_to_id: list[int | None],
|
||||
max_loras: int,
|
||||
vocab_size: int,
|
||||
extra_vocab_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
self.is_prefill = mapping.is_prefill
|
||||
self._update_base_metadata(
|
||||
mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size
|
||||
)
|
||||
self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size)
|
||||
|
||||
def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
|
||||
return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
|
||||
|
||||
@ -166,6 +166,16 @@ def parse_fine_tuned_lora_name(
|
||||
raise ValueError(f"{name} is unsupported LoRA weight")
|
||||
|
||||
|
||||
def is_base_embeddding_weights(name: str) -> bool:
|
||||
# hardcoded subfixes for input & output embedding weights
|
||||
input_embedding_subfix = ".embed_tokens.base_layer.weight"
|
||||
output_embedding_subfix = ".lm_head.base_layer.weight"
|
||||
|
||||
return name.endswith(input_embedding_subfix) or name.endswith(
|
||||
output_embedding_subfix
|
||||
)
|
||||
|
||||
|
||||
def is_regex_target_modules(
|
||||
load_modules: str | list[str], expected_lora_modules: list[str]
|
||||
) -> bool:
|
||||
|
||||
@ -121,8 +121,7 @@ class WorkerLoRAManager:
|
||||
lora_model_id=lora_request.lora_int_id,
|
||||
device="cpu",
|
||||
dtype=self.lora_config.lora_dtype,
|
||||
target_embedding_padding=self.vocab_size
|
||||
+ self.lora_config.lora_extra_vocab_size,
|
||||
target_embedding_padding=self.vocab_size,
|
||||
embedding_modules=self.embedding_modules,
|
||||
embedding_padding_modules=self.embedding_padding_modules,
|
||||
tensorizer_config_dict=lora_request.tensorizer_config_dict,
|
||||
@ -143,12 +142,6 @@ class WorkerLoRAManager:
|
||||
# For BadRequestError
|
||||
raise e
|
||||
|
||||
if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
|
||||
raise ValueError(
|
||||
f"LoRA added vocab size {lora.extra_vocab_size} "
|
||||
f"is greater than lora_extra_vocab_size "
|
||||
f"{self.lora_config.lora_extra_vocab_size}."
|
||||
)
|
||||
return lora
|
||||
|
||||
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
|
||||
|
||||
@ -46,7 +46,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE,
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
@ -261,29 +260,16 @@ class GraniteModel(nn.Module):
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
lora_vocab = (
|
||||
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
|
||||
if lora_config
|
||||
else 0
|
||||
)
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.org_vocab_size = config.vocab_size
|
||||
|
||||
if get_pp_group().is_first_rank or (
|
||||
config.tie_word_embeddings and get_pp_group().is_last_rank
|
||||
):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config
|
||||
else lora_config.lora_vocab_padding_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
@ -420,28 +406,18 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.model = GraniteModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config
|
||||
else lora_config.lora_vocab_padding_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
@ -453,7 +429,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
logit_scale /= config.logits_scaling
|
||||
|
||||
self.logits_processor = LogitsProcessor(
|
||||
self.unpadded_vocab_size, config.vocab_size, scale=logit_scale
|
||||
config.vocab_size, scale=logit_scale
|
||||
)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
@ -47,7 +47,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE,
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
@ -368,24 +367,18 @@ class LlamaModel(nn.Module):
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
lora_vocab = (
|
||||
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
|
||||
if lora_config
|
||||
else 0
|
||||
)
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.org_vocab_size = config.vocab_size
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
if get_pp_group().is_first_rank or (
|
||||
config.tie_word_embeddings and get_pp_group().is_last_rank
|
||||
):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
@ -562,9 +555,7 @@ class LlamaForCausalLM(
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.model = self._init_model(
|
||||
vllm_config=vllm_config,
|
||||
@ -573,20 +564,9 @@ class LlamaForCausalLM(
|
||||
)
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=(
|
||||
DEFAULT_VOCAB_PADDING_SIZE
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config
|
||||
else lora_config.lora_vocab_padding_size
|
||||
),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
@ -595,7 +575,7 @@ class LlamaForCausalLM(
|
||||
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(
|
||||
self.unpadded_vocab_size, config.vocab_size, logit_scale
|
||||
config.vocab_size, scale=logit_scale
|
||||
)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
@ -51,7 +51,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE,
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
@ -301,23 +300,18 @@ class MixtralModel(nn.Module):
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
lora_vocab = (
|
||||
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
|
||||
if lora_config
|
||||
else 0
|
||||
)
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.org_vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
)
|
||||
|
||||
self.enable_eplb = parallel_config.enable_eplb
|
||||
@ -508,34 +502,24 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.model = MixtralModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config
|
||||
else lora_config.lora_vocab_padding_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(
|
||||
self.unpadded_vocab_size, config.vocab_size
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
@ -74,5 +74,5 @@ class TeleFLMForCausalLM(LlamaForCausalLM):
|
||||
self.output_mult = self.config.output_mult / self.mup_scale_factor
|
||||
logit_scale = self.output_mult
|
||||
self.logits_processor = LogitsProcessor(
|
||||
self.unpadded_vocab_size, self.config.vocab_size, logit_scale
|
||||
self.config.vocab_size, scale=logit_scale
|
||||
)
|
||||
|
||||
@ -219,9 +219,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.hidden_size = model_config.get_hidden_size()
|
||||
self.vocab_size = model_config.get_vocab_size()
|
||||
|
||||
if self.lora_config is not None:
|
||||
self.vocab_size += self.lora_config.lora_extra_vocab_size
|
||||
|
||||
# Multi-modal data support
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
self.uses_mrope = model_config.uses_mrope
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user