diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index d8ff9339bb49b..9d38ec5422794 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -250,6 +250,16 @@ def olmoe_lora_files(): return snapshot_download(repo_id="jeeejeee/olmoe-instruct-text2sql-spider") +@pytest.fixture(scope="session") +def qwen3_lora_files(): + return snapshot_download(repo_id="charent/self_cognition_Alice") + + +@pytest.fixture(scope="session") +def llama32_lora_files(): + return snapshot_download(repo_id="jeeejeee/llama32-3b-text2sql-spider") + + @pytest.fixture def reset_default_device(): """ diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 8f18f01441932..9df3a07a9e5e9 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -136,7 +136,6 @@ def populate_loras( id_to_index: list[int | None], layer: BaseLayerWithLoRA, layer_weights: torch.Tensor, - generate_embeddings_tensor: int = 0, repeats: int = 1, ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]: """This method populates the lora layers with lora weights. @@ -148,8 +147,6 @@ def populate_loras( layer: the LoRAlayer to populate. layer_weights: the PyTorch tensor containing the layer's weights. - generate_embeddings_tensor: whether to generate an - embeddings tensor for each LoRA. repeats: must only be set for column parallel packed layers. Indicates the number of loras to compose together to create a single lora layer. @@ -171,7 +168,6 @@ def populate_loras( sublora = DummyLoRAManager(layer_weights.device).init_random_lora( module_name=f"fake_{i}", weight=layer_weights, - generate_embeddings_tensor=generate_embeddings_tensor, ) sublora.lora_b = sublora.lora_b[ (sublora_len * i) : (sublora_len * (i + 1)), : @@ -185,7 +181,6 @@ def populate_loras( slot_idx, lora_a=lora.lora_a, lora_b=lora.lora_b, - embeddings_tensor=lora.embeddings_tensor, ) lora_dict[lora_id] = lora @@ -306,7 +301,6 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: id_to_index, max_loras, vocab_size, - lora_config.lora_extra_vocab_size, ) lora_result = lora_embedding(torch.cat(inputs)) @@ -344,7 +338,6 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: id_to_index, max_loras, vocab_size, - lora_config.lora_extra_vocab_size, ) lora_result = lora_embedding(torch.cat(inputs)) @@ -354,149 +347,6 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) -@torch.inference_mode() -# @pytest.mark.skip( -# reason="Fails when loras are in any slot other than the first.") -@pytest.mark.parametrize("num_loras", [1, 2, 4]) -@pytest.mark.parametrize("device", DEVICES) -@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) -@pytest.mark.parametrize("stage", STAGES) -def test_embeddings_with_new_embeddings( - dist_init, num_loras, device, vocab_size, stage -) -> None: - if current_platform.is_cuda_alike(): - torch.cuda.set_device(device) - - torch.set_default_device(device) - max_loras = 8 - punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) - assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig( - max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16 - ) - - def create_random_embedding_layer(): - embedding = VocabParallelEmbedding(vocab_size, 256) - embedding_data = torch.rand_like(embedding.weight.data) - embedding.weight.data = embedding_data - embedding.weight.data[vocab_size:, :] = 0 - 
expanded_embedding = VocabParallelEmbedding( - vocab_size + lora_config.lora_extra_vocab_size * max_loras, - 256, - org_num_embeddings=vocab_size, - ) - expanded_embedding.weight.data[:vocab_size, :] = embedding_data - # We need to deepcopy the embedding as it will be modified - # in place - lora_embedding = VocabParallelEmbeddingWithLoRA(deepcopy(expanded_embedding)) - lora_embedding.create_lora_weights(max_loras, lora_config) - - return expanded_embedding, lora_embedding - - for i in range(NUM_RANDOM_SEEDS): - set_random_seed(i) - - id_to_index = get_random_id_to_index(num_loras, max_loras) - expanded_embedding, lora_embedding = create_random_embedding_layer() - lora_dict, _ = populate_loras( - id_to_index, - layer=lora_embedding, - layer_weights=torch.zeros( - (256, vocab_size + lora_config.lora_extra_vocab_size) - ), - generate_embeddings_tensor=256, - ) - - lora_embedding.set_mapping(punica_wrapper) - # All embeddings tensors have the same shape. - embeddings_tensors = [ - lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys()) - ] - embeddings_tensor_len = embeddings_tensors[0].shape[0] - - # Add empty embeddings_tensors for unoccupied lora slots. - for _ in range(max_loras - len(embeddings_tensors)): - embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape)) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=list(lora_dict.keys()), - num_inputs=num_loras * 3, - input_size=(200,), - input_range=(1, vocab_size), - device=device, - ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - vocab_size, - lora_config.lora_extra_vocab_size, - ) - original_inputs = deepcopy(inputs) - - # Force some of the inputs to be in the extended embeddings range - # to guarantee that their behavior is tested. 
- for input_, original_input_, lora_id in zip( - inputs, original_inputs, prompt_mapping - ): - embedding_id = lora_id - 1 - input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len) - original_input_[-1] = vocab_size - input_[-2] = vocab_size + ((embedding_id + 1) * embeddings_tensor_len - 1) - original_input_[-2] = vocab_size + embeddings_tensor_len - 1 - - expanded_embedding.weight[ - vocab_size : vocab_size + (embeddings_tensor_len * max_loras) - ] = torch.cat(embeddings_tensors) - - lora_result = lora_embedding(torch.cat(original_inputs)) - - expected_results: list[torch.Tensor] = [] - for input_, original_input_, lora_id in zip( - inputs, original_inputs, prompt_mapping - ): - lora = lora_dict[lora_id] - result = expanded_embedding(input_) - after_a = F.embedding( - original_input_, - lora.lora_a.T, - ) - result += after_a @ lora.lora_b.T - expected_results.append(result) - expected_result = torch.cat(expected_results) - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) - - # Check that resetting the lora weights succeeds - - for slot_idx in range(max_loras): - lora_embedding.reset_lora(slot_idx) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=[0], - num_inputs=num_loras * 3, - input_size=(200,), - input_range=(1, vocab_size), - device=device, - ) - original_inputs = deepcopy(inputs) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - vocab_size, - lora_config.lora_extra_vocab_size, - ) - lora_result = lora_embedding(torch.cat(original_inputs)) - expected_result = expanded_embedding(torch.cat(inputs)) - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) - - @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @@ -518,16 +368,13 @@ def test_lm_head_logits_processor( def _pretest(): linear = ParallelLMHead( - vocab_size + lora_config.lora_extra_vocab_size, - 1024, - vocab_size, + num_embeddings=vocab_size, + embedding_dim=1024, params_dtype=torch.float16, ) linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data[:, vocab_size:] = 0 - logits_processor = LogitsProcessor( - vocab_size + lora_config.lora_extra_vocab_size, vocab_size - ) + logits_processor = LogitsProcessor(vocab_size) lora_logits_processor = LogitsProcessorWithLoRA( logits_processor, 1024, linear.weight.dtype, linear.weight.device, None ) @@ -541,15 +388,12 @@ def test_lm_head_logits_processor( id_to_index = get_random_id_to_index(num_loras, max_loras) linear, logits_processor, lora_logits_processor = _pretest() lora_logits_processor.set_mapping(punica_wrapper) - # NOTE: all the generated loras share the same embeddings tensor. 
+ lora_dict, _ = populate_loras( id_to_index, layer=lora_logits_processor, layer_weights=linear.weight, - generate_embeddings_tensor=1024, ) - embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor - embeddings_tensor_len = embeddings_tensor.shape[0] inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=list(lora_dict.keys()), @@ -565,7 +409,6 @@ def test_lm_head_logits_processor( id_to_index, max_loras, vocab_size, - lora_config.lora_extra_vocab_size, ) input_ = torch.rand(20, 1024) @@ -575,23 +418,16 @@ def test_lm_head_logits_processor( original_lm_head = deepcopy(linear) - linear.weight[ - logits_processor.org_vocab_size : logits_processor.org_vocab_size - + embeddings_tensor_len - ] = embeddings_tensor - - logits_processor.org_vocab_size = vocab_size + lora_config.lora_extra_vocab_size expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): lora = lora_dict[lora_id] result = logits_processor._get_logits( hidden_states=input_, lm_head=linear, embedding_bias=None ) - result[:, vocab_size + embeddings_tensor_len :] = float("-inf") + result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) - logits_processor.org_vocab_size = vocab_size # Check that resetting the lora weights succeeds @@ -612,7 +448,6 @@ def test_lm_head_logits_processor( id_to_index, max_loras, vocab_size, - lora_config.lora_extra_vocab_size, ) lora_result = lora_logits_processor._get_logits( @@ -694,7 +529,6 @@ def test_linear_replicated( id_to_index, max_loras, 512, - lora_config.lora_extra_vocab_size, ) lora_result = lora_linear(torch.cat(inputs))[0] @@ -726,7 +560,10 @@ def test_linear_replicated( lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( - lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size + lora_mapping, + id_to_index, + max_loras, + 512, ) lora_result = lora_linear(torch.cat(inputs))[0] @@ -817,7 +654,6 @@ def test_linear_parallel( id_to_index, max_loras, 512, - lora_config.lora_extra_vocab_size, ) lora_result = lora_linear(torch.cat(inputs))[0] @@ -849,7 +685,10 @@ def test_linear_parallel( lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( - lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size + lora_mapping, + id_to_index, + max_loras, + 512, ) lora_result = lora_linear(torch.cat(inputs))[0] @@ -963,7 +802,6 @@ def test_column_parallel_packed( id_to_index, max_loras, 512, - lora_config.lora_extra_vocab_size, ) lora_result = lora_linear(torch.cat(inputs))[0] @@ -1000,7 +838,6 @@ def test_column_parallel_packed( id_to_index, max_loras, 512, - lora_config.lora_extra_vocab_size, ) lora_result = lora_linear(torch.cat(inputs))[0] diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index 7bbd1e364d19e..18704fa6e45de 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -13,17 +13,27 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from ..utils import VLLM_PATH, create_new_process_for_each_test, multi_gpu_test -MODEL_PATH = "meta-llama/Llama-2-7b-hf" +PROMPT_TEMPLATE = """<|eot_id|><|start_header_id|>user<|end_header_id|> +I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes 
the request. +" +##Instruction: +candidate_poll contains tables such as candidate, people. Table candidate has columns such as Candidate_ID, People_ID, Poll_Source, Date, Support_rate, Consider_rate, Oppose_rate, Unsure_rate. Candidate_ID is the primary key. +Table people has columns such as People_ID, Sex, Name, Date_of_Birth, Height, Weight. People_ID is the primary key. +The People_ID of candidate is the foreign key of People_ID of people. +###Input: +{context} +###Response:<|eot_id|><|start_header_id|>assistant<|end_header_id|> +""" # noqa: E501 EXPECTED_LORA_OUTPUT = [ - " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501 - " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", - " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501 - " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501 - " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", - " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' ", # noqa: E501 + "SELECT count(*) FROM candidate", + "SELECT count(*) FROM candidate", + "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501 + "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501 ] +MODEL_PATH = "meta-llama/Llama-3.2-3B-Instruct" + def do_sample( llm: vllm.LLM, @@ -32,18 +42,19 @@ def do_sample( tensorizer_config_dict: dict | None = None, ) -> list[str]: prompts = [ - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501 - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501 - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? 
[/user] [assistant]", # noqa: E501 - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]", # noqa: E501 + PROMPT_TEMPLATE.format(context="How many candidates are there?"), + PROMPT_TEMPLATE.format(context="Count the number of candidates."), + PROMPT_TEMPLATE.format( + context="Which poll resource provided the most number of candidate information?" # noqa: E501 + ), + PROMPT_TEMPLATE.format( + context="Return the poll resource associated with the most candidates." + ), ] sampling_params = vllm.SamplingParams( - temperature=0, max_tokens=256, skip_special_tokens=False, stop=["[/assistant]"] + temperature=0, max_tokens=64, stop=["<|im_end|>"] ) - if tensorizer_config_dict is not None: outputs = llm.generate( prompts, @@ -75,13 +86,15 @@ def do_sample( return generated_texts -def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None = None): +def generate_and_test( + llm, llama32_lora_files, tensorizer_config_dict: dict | None = None +): print("lora adapter created") print("lora 1") assert ( do_sample( llm, - sql_lora_files, + llama32_lora_files, tensorizer_config_dict=tensorizer_config_dict, lora_id=1, ) @@ -92,7 +105,7 @@ def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None = assert ( do_sample( llm, - sql_lora_files, + llama32_lora_files, tensorizer_config_dict=tensorizer_config_dict, lora_id=2, ) @@ -104,51 +117,52 @@ def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None = @create_new_process_for_each_test() @pytest.mark.parametrize("cudagraph_specialize_lora", [True, False]) -def test_llama_lora(sql_lora_files, cudagraph_specialize_lora: bool): +def test_llama_lora(llama32_lora_files, cudagraph_specialize_lora: bool): llm = vllm.LLM( MODEL_PATH, - tokenizer=sql_lora_files, enable_lora=True, # also test odd max_num_seqs - max_num_seqs=13, + max_num_seqs=7, + max_model_len=1024, max_loras=4, compilation_config=vllm.config.CompilationConfig( cudagraph_specialize_lora=cudagraph_specialize_lora, ), ) - generate_and_test(llm, sql_lora_files) + generate_and_test(llm, llama32_lora_files) @multi_gpu_test(num_gpus=4) -def test_llama_lora_tp4(sql_lora_files): +def test_llama_lora_tp4(llama32_lora_files): llm = vllm.LLM( MODEL_PATH, - tokenizer=sql_lora_files, enable_lora=True, - max_num_seqs=16, + max_num_seqs=7, + max_model_len=1024, max_loras=4, tensor_parallel_size=4, ) - generate_and_test(llm, sql_lora_files) + generate_and_test(llm, llama32_lora_files) @multi_gpu_test(num_gpus=4) -def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): +def test_llama_lora_tp4_fully_sharded_loras(llama32_lora_files): llm = vllm.LLM( MODEL_PATH, - tokenizer=sql_lora_files, enable_lora=True, - max_num_seqs=16, + max_num_seqs=8, max_loras=4, + max_model_len=1024, tensor_parallel_size=4, fully_sharded_loras=True, ) - generate_and_test(llm, sql_lora_files) + generate_and_test(llm, llama32_lora_files) @multi_gpu_test(num_gpus=2) def test_tp2_serialize_and_deserialize_lora( - tmp_path, sql_lora_files, sql_lora_huggingface_id + tmp_path, + llama32_lora_files, ): # Run the tensorizing of the LoRA adapter and the model in a subprocess # to guarantee cleanup @@ -157,7 +171,7 @@ def test_tp2_serialize_and_deserialize_lora( model_name = "model-rank-%03d.tensors" model_ref = MODEL_PATH - lora_path = sql_lora_huggingface_id + lora_path = llama32_lora_files 
suffix = "test" try: result = subprocess.run( @@ -195,12 +209,12 @@ def test_tp2_serialize_and_deserialize_lora( loaded_llm = LLM( model=model_ref, - tokenizer=sql_lora_files, load_format="tensorizer", enable_lora=True, enforce_eager=True, model_loader_extra_config=tensorizer_config, - max_num_seqs=13, + max_num_seqs=7, + max_model_len=1024, tensor_parallel_size=2, max_loras=2, ) @@ -211,7 +225,7 @@ def test_tp2_serialize_and_deserialize_lora( print("lora 1") assert ( do_sample( - loaded_llm, sql_lora_files, tensorizer_config_dict=tc_as_dict, lora_id=1 + loaded_llm, llama32_lora_files, tensorizer_config_dict=tc_as_dict, lora_id=1 ) == EXPECTED_LORA_OUTPUT ) diff --git a/tests/lora/test_lora_functions.py b/tests/lora/test_lora_functions.py index e914393fee8aa..1c692630284d0 100644 --- a/tests/lora/test_lora_functions.py +++ b/tests/lora/test_lora_functions.py @@ -13,8 +13,8 @@ from vllm.entrypoints.openai.api_server import ( from vllm.lora.request import LoRARequest from vllm.v1.engine.llm_engine import LLMEngine -MODEL_PATH = "meta-llama/Llama-2-7b-hf" -LORA_MODULE_PATH = "yard1/llama-2-7b-sql-lora-test" +MODEL_PATH = "Qwen/Qwen3-0.6B" +LORA_MODULE_PATH = "charent/self_cognition_Alice" LORA_RANK = 8 diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index e7816031142e3..24d4dfca46d62 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -48,9 +48,6 @@ DEFAULT_DTYPE = torch.get_default_dtype() @pytest.mark.parametrize("device", DEVICES) def test_from_lora_tensors(sql_lora_files, device): tensors = load_file(os.path.join(sql_lora_files, "adapter_model.safetensors")) - new_embeddings = load_file( - os.path.join(sql_lora_files, "new_embeddings.safetensors") - ) peft_helper = PEFTHelper.from_local_dir( sql_lora_files, max_position_embeddings=4096 @@ -60,7 +57,6 @@ def test_from_lora_tensors(sql_lora_files, device): tensors, peft_helper=peft_helper, device=device, - embeddings=new_embeddings, embedding_modules=EMBEDDING_MODULES, embedding_padding_modules=EMBEDDING_PADDING_MODULES, ) @@ -76,18 +72,6 @@ def test_from_lora_tensors(sql_lora_files, device): f"{lora.lora_a.shape=}, {lora.lora_b.shape=}" ) assert lora.lora_a.shape[0] == 8 - embeddings_module = next( - (k for k in EMBEDDING_MODULES if k in module_name), None - ) - if embeddings_module: - assert torch.equal( - lora.embeddings_tensor, - new_embeddings[EMBEDDING_MODULES[embeddings_module]].to( - device=lora.embeddings_tensor.device - ), - ) - else: - assert lora.embeddings_tensor is None def create_lora( @@ -552,9 +536,7 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path worker_adapter_manager = WorkerLoRAManager( vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES ) - worker_adapter_manager.vocab_size = ( - dummy_model_gate_up.unpadded_vocab_size - lora_config.lora_extra_vocab_size - ) + worker_adapter_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size worker_adapter_manager.create_lora_manager(dummy_model_gate_up) dummy_lora_files = f"{tmp_path}/lora_adapter" diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index c97f8debd1b9a..b163559a9414d 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -20,11 +20,12 @@ from vllm.lora.models import LoRAMapping from vllm.lora.request import LoRARequest from vllm.v1.worker.gpu_worker import Worker +MODEL_PATH = "Qwen/Qwen3-0.6B" NUM_LORAS = 16 @patch.dict(os.environ, {"RANK": "0"}) -def test_worker_apply_lora(sql_lora_files): +def 
test_worker_apply_lora(qwen3_lora_files): def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]): lora_mapping = LoRAMapping([], []) @@ -34,9 +35,10 @@ def test_worker_apply_lora(sql_lora_files): vllm_config = VllmConfig( model_config=ModelConfig( - "meta-llama/Llama-2-7b-hf", + MODEL_PATH, seed=0, dtype="float16", + max_model_len=127, enforce_eager=True, ), load_config=LoadConfig( @@ -73,7 +75,7 @@ def test_worker_apply_lora(sql_lora_files): assert worker.list_loras() == set() lora_requests = [ - LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(NUM_LORAS) + LoRARequest(str(i + 1), i + 1, qwen3_lora_files) for i in range(NUM_LORAS) ] set_active_loras(worker, lora_requests) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index d30b77f094665..6aba5299b5829 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -28,7 +28,6 @@ class DummyLoRAManager: module_name: str, weight: torch.Tensor, rank: int = 8, - generate_embeddings_tensor: int = 0, ): lora = LoRALayerWeights( module_name, @@ -41,13 +40,6 @@ class DummyLoRAManager: [weight.shape[0], rank], dtype=weight.dtype, device=self._device ), ) - if generate_embeddings_tensor: - lora.embeddings_tensor = torch.rand( - 5, - generate_embeddings_tensor, - dtype=weight.dtype, - device=self._device, - ) self.set_module_lora(module_name, lora) return lora diff --git a/vllm/config/lora.py b/vllm/config/lora.py index 84e92eef40077..072e0ec2104f5 100644 --- a/vllm/config/lora.py +++ b/vllm/config/lora.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib -from typing import TYPE_CHECKING, Any, ClassVar, Literal +from typing import TYPE_CHECKING, Any, Literal import torch from pydantic import ConfigDict, Field, model_validator @@ -11,7 +11,6 @@ from typing_extensions import Self from vllm.config.utils import config from vllm.logger import init_logger -from vllm.platforms import current_platform if TYPE_CHECKING: from vllm.config import ModelConfig @@ -46,19 +45,6 @@ class LoRAConfig: `max_loras`.""" lora_dtype: torch.dtype | LoRADType = "auto" """Data type for LoRA. If auto, will default to base model dtype.""" - lora_extra_vocab_size: LoRAExtraVocabSize = Field( - default=256, - deprecated=( - "`lora_extra_vocab_size` is deprecated and will be removed " - "in v0.12.0. Additional vocabulary support for " - "LoRA adapters is being phased out." - ), - ) - """(Deprecated) Maximum size of extra vocabulary that can be present in a - LoRA adapter. 
Will be removed in v0.12.0.""" - lora_vocab_padding_size: ClassVar[int] = ( - current_platform.get_lora_vocab_padding_size() - ) default_mm_loras: dict[str, str] | None = None """Dictionary mapping specific modalities to LoRA model paths; this field is only applicable to multimodal models and should be leveraged when a @@ -87,8 +73,6 @@ class LoRAConfig: factors.append(self.max_loras) factors.append(self.fully_sharded_loras) factors.append(self.lora_dtype) - factors.append(self.lora_extra_vocab_size) - factors.append(self.lora_vocab_padding_size) hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 74828bc109cbe..bcb90119f9b04 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -484,7 +484,6 @@ class EngineArgs: fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras max_cpu_loras: int | None = LoRAConfig.max_cpu_loras lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype - lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override @@ -1011,9 +1010,6 @@ class EngineArgs: ) lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"]) lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"]) - lora_group.add_argument( - "--lora-extra-vocab-size", **lora_kwargs["lora_extra_vocab_size"] - ) lora_group.add_argument( "--lora-dtype", **lora_kwargs["lora_dtype"], @@ -1680,7 +1676,6 @@ class EngineArgs: max_loras=self.max_loras, default_mm_loras=self.default_mm_loras, fully_sharded_loras=self.fully_sharded_loras, - lora_extra_vocab_size=self.lora_extra_vocab_size, lora_dtype=self.lora_dtype, max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 diff --git a/vllm/lora/layers/base.py b/vllm/lora/layers/base.py index 0c7e806848892..62326c05b2bd1 100644 --- a/vllm/lora/layers/base.py +++ b/vllm/lora/layers/base.py @@ -44,7 +44,6 @@ class BaseLayerWithLoRA(nn.Module): index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - embeddings_tensor: torch.Tensor | None, ): """Overwrites lora tensors at index.""" ... 
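Editorial note, not part of the patch: with embeddings_tensor dropped from set_lora across the LoRA layer classes, callers now stage only the two low-rank matrices into a slot. A minimal sketch of the updated call shape, where layer, slot_idx, and lora are hypothetical placeholders for a BaseLayerWithLoRA instance, a slot index, and a LoRALayerWeights object:

    # Hypothetical caller mirroring LoRAModelManager.activate_adapter after this change:
    # only lora_a and lora_b are written into the slot; there is no per-adapter
    # embeddings tensor to copy anymore.
    layer.set_lora(
        slot_idx,
        lora_a=lora.lora_a,
        lora_b=lora.lora_b,
    )
    # Clearing a slot is unchanged:
    layer.reset_lora(slot_idx)
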
diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py index 3db4165e20176..e85c5bd70b072 100644 --- a/vllm/lora/layers/base_linear.py +++ b/vllm/lora/layers/base_linear.py @@ -96,7 +96,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - embeddings_tensor: torch.Tensor | None, ): # Except for QKVParallelLinearWithLoRA and # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py index 637ded9b2a0f0..273c4950e3239 100644 --- a/vllm/lora/layers/column_parallel_linear.py +++ b/vllm/lora/layers/column_parallel_linear.py @@ -248,7 +248,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - embeddings_tensor: torch.Tensor | None, ): self.reset_lora(index) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 3291c41fcda1e..adf30855cafc3 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -406,8 +406,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - embeddings_tensor: torch.Tensor | None, - bias: torch.Tensor | None = None, ): """Overwrites lora tensors at index.""" self.reset_lora(index) diff --git a/vllm/lora/layers/logits_processor.py b/vllm/lora/layers/logits_processor.py index adc5e861f57fb..06f92652031e1 100644 --- a/vllm/lora/layers/logits_processor.py +++ b/vllm/lora/layers/logits_processor.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math import torch import torch.nn as nn @@ -108,22 +107,13 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ( max_loras, 1, - # Pad for kernel compatibility - math.ceil( - self.base_layer.vocab_size / lora_config.lora_vocab_padding_size - ) - * lora_config.lora_vocab_padding_size, + self.base_layer.vocab_size, lora_config.max_lora_rank, ), dtype=lora_config.lora_dtype, device=self.device, ) - self.embeddings_tensors = torch.full( - (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), - fill_value=float("-inf"), - dtype=self.dtype, - device=self.device, - ) + if self.sharded_to_full_mapping is not None: self.sharded_to_full_mapping_gpu = torch.tensor( self.sharded_to_full_mapping, device=self.device, dtype=torch.long @@ -134,14 +124,12 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 - self.embeddings_tensors[index] = float("-inf") def set_lora( self, index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - embeddings_tensor: torch.Tensor | None, ): self.reset_lora(index) self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_( @@ -150,12 +138,6 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_( lora_b, non_blocking=True ) - if embeddings_tensor is not None: - self.embeddings_tensors[ - index, - : embeddings_tensor.shape[0], - : embeddings_tensor.shape[1], - ] = embeddings_tensor def _get_logits( self, @@ -193,39 +175,6 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): # token_id: [0, 1, 2, 3, 4, 5, -1, -1] logits = logits[:, self.sharded_to_full_mapping_gpu] - lora_logits = torch.empty( - self.embeddings_tensors.shape[0] + 1, - self.embeddings_tensors.shape[1], - 
hidden_states.shape[0], - dtype=self.embeddings_tensors.dtype, - device=self.embeddings_tensors.device, - ) - torch.matmul(self.embeddings_tensors, hidden_states.T, out=lora_logits[:-1]) - - neg_inf, pos_inf = current_platform.get_infinity_values(lora_logits.dtype) - - lora_logits[-1] = neg_inf - lora_logits = lora_logits.mT - indices_padded = self.punica_wrapper.sampler_indices_padded - - if current_platform.is_tpu() or current_platform.is_xpu(): - indices_padded = indices_padded[: logits.size(0)] - - lora_logits = ( - lora_logits.reshape( - lora_logits.shape[0] * lora_logits.shape[1], - lora_logits.shape[2], - ) - .index_select(0, indices_padded) - .nan_to_num_(nan=neg_inf, posinf=pos_inf, neginf=neg_inf) - ) - - logits[ - :, - self.base_layer.org_vocab_size : self.base_layer.org_vocab_size - + lora_logits.shape[1], - ] = lora_logits - lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_logits( logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, 1.0 ) diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py index ca4ad8012e9c3..5b1f7886bc238 100644 --- a/vllm/lora/layers/vocal_parallel_embedding.py +++ b/vllm/lora/layers/vocal_parallel_embedding.py @@ -46,19 +46,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): self.embeddings_slice = None self.embeddings_weights = None - self.embeddings_tensors = torch.zeros( - ( - max_loras, - lora_config.lora_extra_vocab_size, - self.base_layer.embedding_dim, - ), - dtype=self.base_layer.weight.dtype, - device=self.base_layer.weight.device, - ) self.lora_a_stacked = torch.zeros( ( max_loras, - self.base_layer.org_vocab_size + lora_config.lora_extra_vocab_size, + self.base_layer.org_vocab_size, lora_config.max_lora_rank, ), dtype=lora_config.lora_dtype, @@ -82,14 +73,12 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 - self.embeddings_tensors[index] = 0 def set_lora( self, index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - embeddings_tensor: torch.Tensor | None, ): self.reset_lora(index) # NOTE self.lora_a_stacked is row-major, and lora_a is col-major, @@ -100,36 +89,18 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_( lora_b, non_blocking=True ) - if embeddings_tensor is not None: - self.embeddings_tensors[ - index, - : embeddings_tensor.shape[0], - : embeddings_tensor.shape[1], - ].copy_(embeddings_tensor, non_blocking=True) - if self.embeddings_slice is not None: - # TODO(yard1): Optimize this copy, we don't need to copy - # everything, just the modified part - embeddings = self.embeddings_tensors.view( - self.embeddings_tensors.shape[0] * self.embeddings_tensors.shape[1], - self.embeddings_tensors.shape[2], - )[self.embeddings_slice[0] : self.embeddings_slice[1]] - assert self.embeddings_weights is not None - self.embeddings_weights[: embeddings.shape[0]].copy_(embeddings) def forward(self, x: torch.Tensor) -> torch.Tensor: - added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, 1, 0) - # NB: Don't use torch.narrow here. 
torch.narrow triggers some # Dynamic Shape specialization in torch.compile num_tokens = x.shape[0] indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens] - indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens] full_lora_a_embeddings = F.embedding( x + indices_1, self.lora_a_stacked_2d, ) - full_output = self.base_layer.forward(x + (indices_0 * added_tokens_mask)) + full_output = self.base_layer.forward(x) full_output_org = full_output if full_output.ndim == 3: diff --git a/vllm/lora/lora_weights.py b/vllm/lora/lora_weights.py index 7691481d5039e..f0d8e22194050 100644 --- a/vllm/lora/lora_weights.py +++ b/vllm/lora/lora_weights.py @@ -21,7 +21,6 @@ class LoRALayerWeights: lora_alpha: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - embeddings_tensor: torch.Tensor | None = None, scaling: float | None = None, ) -> None: self.module_name = module_name @@ -29,7 +28,6 @@ class LoRALayerWeights: self.lora_alpha = lora_alpha self.lora_a = lora_a self.lora_b = lora_b - self.embeddings_tensor = embeddings_tensor if scaling is None: self.scaling = self.lora_alpha / self.rank @@ -56,18 +54,11 @@ class LoRALayerWeights: def is_packed(self) -> bool: return False - @property - def extra_vocab_size(self) -> int: - return ( - self.embeddings_tensor.shape[0] if self.embeddings_tensor is not None else 0 - ) - @classmethod def from_config( cls, module_name: str, peft_helper: PEFTHelper, - embeddings_tensor: torch.Tensor | None = None, ) -> "LoRALayerWeights": # lora_a and lora_b are set to None for config-based construction return cls( @@ -76,7 +67,6 @@ class LoRALayerWeights: peft_helper.lora_alpha, None, None, - embeddings_tensor, peft_helper.vllm_lora_scaling_factor, ) @@ -89,7 +79,6 @@ class LoRALayerWeights: rank: int, dtype: torch.dtype, device: torch.types.Device, - embeddings_tensor_dim: int | None = None, ) -> "LoRALayerWeights": pin_memory = str(device) == "cpu" and is_pin_memory_available() lora_a = torch.zeros( @@ -99,24 +88,12 @@ class LoRALayerWeights: [output_dim, rank], dtype=dtype, device=device, pin_memory=pin_memory ) - embeddings_tensor = ( - torch.rand( - 10, - embeddings_tensor_dim, - dtype=dtype, - device=device, - pin_memory=pin_memory, - ) - if embeddings_tensor_dim - else None - ) return cls( module_name, rank=rank, lora_alpha=1, lora_a=lora_a, lora_b=lora_b, - embeddings_tensor=embeddings_tensor, ) @@ -139,7 +116,6 @@ class PackedLoRALayerWeights(LoRALayerWeights): lora_a=lora_a, lora_b=lora_b, scaling=scaling, # type: ignore - embeddings_tensor=None, ) self.lora_alphas = lora_alphas if scaling is None: diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 02c252f15bfab..eb11cd0afc487 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -21,6 +21,7 @@ from vllm.lora.utils import ( from_layer, from_layer_logits_processor, get_supported_lora_modules, + is_base_embeddding_weights, is_regex_target_modules, parse_fine_tuned_lora_name, process_packed_modules_mapping, @@ -93,14 +94,6 @@ class LoRAModel: loras=self.loras.copy(), ) - @property - def extra_vocab_size(self) -> int: - return ( - max(lora.extra_vocab_size for lora in self.loras.values()) - if self.loras - else 0 - ) - def get_lora(self, module_name: str) -> LoRALayerWeights | None: """Get LoRA for a given module by name""" return self.loras.get(module_name, None) @@ -117,7 +110,6 @@ class LoRAModel: peft_helper: PEFTHelper, device: str = "cuda", dtype: torch.dtype | None = None, - embeddings: dict[str, torch.Tensor] | None = None, target_embedding_padding: int | None = None, 
embedding_modules: dict[str, str] | None = None, embedding_padding_modules: list[str] | None = None, @@ -127,24 +119,14 @@ class LoRAModel: pin_memory = str(device) == "cpu" and is_pin_memory_available() loras: dict[str, LoRALayerWeights] = {} for tensor_name, tensor in tensors.items(): + if is_base_embeddding_weights(tensor_name): + continue module_name, is_lora_a = parse_fine_tuned_lora_name( tensor_name, weights_mapper ) if module_name not in loras: - lora_embeddings_tensor = None - if embeddings: - assert embedding_modules is not None - embeddings_module = next( - (k for k in embedding_modules if k in module_name), None - ) - if embeddings_module: - lora_embeddings_tensor = embeddings[ - embedding_modules[embeddings_module] - ].to(device=device, dtype=dtype) - if pin_memory: - lora_embeddings_tensor = lora_embeddings_tensor.pin_memory() loras[module_name] = LoRALayerWeights.from_config( - module_name, peft_helper, lora_embeddings_tensor + module_name, peft_helper ) if is_lora_a: @@ -206,15 +188,17 @@ class LoRAModel: lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt") - new_embeddings_tensor_path = os.path.join( - lora_dir, "new_embeddings.safetensors" - ) - new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin") + # new_embeddings_tensor_path = os.path.join( + # lora_dir, "new_embeddings.safetensors" + # ) + # new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin") tensors: dict[str, torch.Tensor] = {} unexpected_modules: list[list[str] | str] = [] def check_unexpected_modules(modules: dict): for lora_module in modules.keys(): # noqa + if is_base_embeddding_weights(lora_module): + continue module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper) # Handle FSDP file format where experts.base_layer is the # gate_up_proj and experts is the down_proj @@ -300,21 +284,12 @@ class LoRAModel: else: raise ValueError(f"{lora_dir} doesn't contain tensors") - embeddings = None - if os.path.isfile(new_embeddings_tensor_path): - embeddings = safetensors.torch.load_file(new_embeddings_tensor_path) - elif os.path.isfile(new_embeddings_bin_file_path): - embeddings = torch.load( - new_embeddings_bin_file_path, map_location=device, weights_only=True - ) - return cls.from_lora_tensors( lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id, tensors=tensors, peft_helper=peft_helper, device=device, dtype=dtype, - embeddings=embeddings, target_embedding_padding=target_embedding_padding, embedding_modules=embedding_modules, embedding_padding_modules=embedding_padding_modules, @@ -474,7 +449,6 @@ class LoRAModelManager: index, module_lora.lora_a, module_lora.lora_b, - module_lora.embeddings_tensor, ) else: module.reset_lora(index) @@ -505,7 +479,6 @@ class LoRAModelManager: self.lora_index_to_id, self.lora_slots + 1, self.vocab_size, - self.lora_config.lora_extra_vocab_size, ) def remove_all_adapters(self): @@ -616,7 +589,6 @@ class LoRAModelManager: if parts[-1] in embedding_modules: input_dim = ( module.base_layer.org_vocab_size - + self.lora_config.lora_extra_vocab_size if hasattr(module.base_layer, "org_vocab_size") else module.base_layer.weight.shape[1] ) @@ -625,11 +597,6 @@ class LoRAModelManager: if hasattr(module.base_layer, "embedding_dim") else module.base_layer.weight.shape[0] ) - embeddings_tensor_dim = ( - module.base_layer.embedding_dim - if hasattr(module.base_layer, 
"embedding_dim") - else module.base_layer.weight.shape[1] - ) lora = LoRALayerWeights.create_dummy_lora_weights( module_name, input_dim, @@ -637,7 +604,6 @@ class LoRAModelManager: rank, module.lora_a_stacked[0].dtype, "cpu", - embeddings_tensor_dim=embeddings_tensor_dim, ) else: lora = LoRALayerWeights.create_dummy_lora_weights( diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index a6ffbb7b71ce4..7c0fc8167711d 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -31,7 +31,6 @@ class PunicaWrapperABC(ABC): lora_index_to_id: list[int | None], max_loras: int, vocab_size: int, - extra_vocab_size: int, **kwargs, ) -> None: """ @@ -172,8 +171,11 @@ class PunicaWrapperBase(PunicaWrapperABC): lora_index_to_id: list[int | None], max_loras: int, vocab_size: int, - extra_vocab_size: int, ): + # NOTE We have remove lora extra vocab support for now. So we set + # extra_vocab_size alwayzs to 0, and extra_vocab_size will be removed. + + extra_vocab_size = 0 ( base_indices, sampler_indices, @@ -285,12 +287,9 @@ class PunicaWrapperBase(PunicaWrapperABC): lora_index_to_id: list[int | None], max_loras: int, vocab_size: int, - extra_vocab_size: int, **kwargs, ): - self._update_base_metadata( - mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size - ) + self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size) if mapping.is_prefill: # Update metadata required for prefill-related operators. diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index d863a5884d3c5..52138ef0cc3b0 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -65,13 +65,10 @@ class PunicaWrapperGPU(PunicaWrapperBase): lora_index_to_id: list[int | None], max_loras: int, vocab_size: int, - extra_vocab_size: int, **kwargs, ): self.is_prefill = mapping.is_prefill - self._update_base_metadata( - mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size - ) + self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size) # Prepare cuda kernel metadata tensors self.token_mapping_meta.prepare_tensors(self.token_lora_indices) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 090878dcd2546..0888772db54e7 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -292,7 +292,6 @@ class PunicaWrapperTPU(PunicaWrapperBase): lora_index_to_id: list[int | None], max_loras: int, vocab_size: int, - extra_vocab_size: int, ): # Make sure we don't accidentally collect outside operations torch_xla.sync() @@ -313,7 +312,7 @@ class PunicaWrapperTPU(PunicaWrapperBase): lora_index_to_id, max_loras, vocab_size, - extra_vocab_size, + 0, # extra_vocab_size "cpu", ) self._token_lora_indices = self._pad_to_shape( diff --git a/vllm/lora/punica_wrapper/punica_xpu.py b/vllm/lora/punica_wrapper/punica_xpu.py index b95087d0ff834..00c00782896cf 100644 --- a/vllm/lora/punica_wrapper/punica_xpu.py +++ b/vllm/lora/punica_wrapper/punica_xpu.py @@ -43,13 +43,10 @@ class PunicaWrapperXPU(PunicaWrapperBase): lora_index_to_id: list[int | None], max_loras: int, vocab_size: int, - extra_vocab_size: int, **kwargs, ): self.is_prefill = mapping.is_prefill - self._update_base_metadata( - mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size - ) + self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size) def _get_token_lora_indices(self, 
x: torch.Tensor) -> torch.IntTensor: return torch.narrow(self._token_lora_indices, 0, 0, x.size(0)) diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 0f43ff06d8f2b..a49a7d9d1669d 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -166,6 +166,16 @@ def parse_fine_tuned_lora_name( raise ValueError(f"{name} is unsupported LoRA weight") +def is_base_embeddding_weights(name: str) -> bool: + # hardcoded subfixes for input & output embedding weights + input_embedding_subfix = ".embed_tokens.base_layer.weight" + output_embedding_subfix = ".lm_head.base_layer.weight" + + return name.endswith(input_embedding_subfix) or name.endswith( + output_embedding_subfix + ) + + def is_regex_target_modules( load_modules: str | list[str], expected_lora_modules: list[str] ) -> bool: diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index b85151f2c7592..4cc201a6414f1 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -121,8 +121,7 @@ class WorkerLoRAManager: lora_model_id=lora_request.lora_int_id, device="cpu", dtype=self.lora_config.lora_dtype, - target_embedding_padding=self.vocab_size - + self.lora_config.lora_extra_vocab_size, + target_embedding_padding=self.vocab_size, embedding_modules=self.embedding_modules, embedding_padding_modules=self.embedding_padding_modules, tensorizer_config_dict=lora_request.tensorizer_config_dict, @@ -143,12 +142,6 @@ class WorkerLoRAManager: # For BadRequestError raise e - if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: - raise ValueError( - f"LoRA added vocab size {lora.extra_vocab_size} " - f"is greater than lora_extra_vocab_size " - f"{self.lora_config.lora_extra_vocab_size}." - ) return lora def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 1dc205b47753d..cd7ce2fc8f00a 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -46,7 +46,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -261,29 +260,16 @@ class GraniteModel(nn.Module): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or ( config.tie_word_embeddings and get_pp_group().is_last_rank ): self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, ) else: @@ -420,28 +406,18 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = 
config - self.lora_config = lora_config + self.quant_config = quant_config self.model = GraniteModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -453,7 +429,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): logit_scale /= config.logits_scaling self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, scale=logit_scale + config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index d5b49d2fb4c26..ebf8addda4a54 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -47,7 +47,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -368,24 +367,18 @@ class LlamaModel(nn.Module): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + if get_pp_group().is_first_rank or ( config.tie_word_embeddings and get_pp_group().is_last_rank ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=quant_config, ) else: @@ -562,9 +555,7 @@ class LlamaForCausalLM( super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.model = self._init_model( vllm_config=vllm_config, @@ -573,20 +564,9 @@ class LlamaForCausalLM( ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size - ), quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -595,7 +575,7 @@ class LlamaForCausalLM( logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale + config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/mixtral.py 
b/vllm/model_executor/models/mixtral.py index 54ab8dd493e73..0a9c3f136964e 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -51,7 +51,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -301,23 +300,18 @@ class MixtralModel(nn.Module): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + parallel_config = vllm_config.parallel_config self.config = config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab + + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.enable_eplb = parallel_config.enable_eplb @@ -508,34 +502,24 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + self.config = config - self.lora_config = lora_config + self.quant_config = quant_config self.model = MixtralModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) diff --git a/vllm/model_executor/models/teleflm.py b/vllm/model_executor/models/teleflm.py index 8a0bec9dff848..bebd7bcaa9249 100644 --- a/vllm/model_executor/models/teleflm.py +++ b/vllm/model_executor/models/teleflm.py @@ -74,5 +74,5 @@ class TeleFLMForCausalLM(LlamaForCausalLM): self.output_mult = self.config.output_mult / self.mup_scale_factor logit_scale = self.output_mult self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, self.config.vocab_size, logit_scale + self.config.vocab_size, scale=logit_scale ) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index e9eb7cad38f88..923c31c187f31 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -219,9 +219,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.hidden_size = model_config.get_hidden_size() self.vocab_size = model_config.get_vocab_size() - if self.lora_config is not None: - self.vocab_size += 
self.lora_config.lora_extra_vocab_size - # Multi-modal data support self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope
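
For reference, a minimal end-to-end sketch of the updated LoRA test setup (editorial addition, not part of the patch; assumes network access for snapshot_download and a single GPU, and uses a plain prompt in place of the PROMPT_TEMPLATE defined in tests/lora/test_llama_tp.py):

    # Run the Llama-3.2 text2sql adapter with the settings used in the updated tests;
    # no extra-vocab handling is involved anywhere in this path anymore.
    import vllm
    from huggingface_hub import snapshot_download
    from vllm.lora.request import LoRARequest

    lora_path = snapshot_download(repo_id="jeeejeee/llama32-3b-text2sql-spider")

    llm = vllm.LLM(
        "meta-llama/Llama-3.2-3B-Instruct",
        enable_lora=True,
        max_loras=4,
        max_model_len=1024,
    )
    outputs = llm.generate(
        ["How many candidates are there?"],  # stand-in for the SQL prompt template
        vllm.SamplingParams(temperature=0, max_tokens=64),
        lora_request=LoRARequest("sql_adapter", 1, lora_path),
    )
    print(outputs[0].outputs[0].text)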