[Bugfix] Fix JambaForCausalLM LoRA (#14370)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2025-03-07 14:05:47 +08:00 committed by GitHub
parent e5e03c2c1b
commit ddd1ef66ec
4 changed files with 35 additions and 83 deletions

View File

@@ -6,7 +6,6 @@ from typing import TypedDict
 from unittest.mock import MagicMock, patch

 import pytest
-import safetensors
 import torch
 import torch.nn as nn
 from huggingface_hub import snapshot_download
@@ -191,29 +190,6 @@ def mixtral_lora_files_all_target_modules():
     return snapshot_download(repo_id="dyang415/mixtral-lora-v0")


-@pytest.fixture(scope="session")
-def jamba_lora_files():
-    # some of the adapters have unnecessary weights for serving,
-    # hence we remove them
-    def remove_unnecessary_weights(path):
-        lora_path = f"{adapter_path}/adapter_model.safetensors"
-        tensors = safetensors.torch.load_file(lora_path)
-        nonlora_keys = []
-        for k in list(tensors.keys()):
-            if "lora" not in k:
-                nonlora_keys.append(k)
-        for k in nonlora_keys:
-            del tensors[k]
-        safetensors.torch.save_file(tensors, lora_path)
-
-    adapter_path = snapshot_download(
-        repo_id=
-        "hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora")
-
-    remove_unnecessary_weights(adapter_path)
-    return adapter_path
-
-
 @pytest.fixture(scope="session")
 def gemma_lora_files():
     return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")

View File

@@ -1,54 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-import pytest
-import torch
-
-import vllm
-from vllm.lora.request import LoRARequest
-
-MODEL_PATH = "ai21labs/AI21-Jamba-1.5-Mini"
-
-MAX_TOKENS = 40
-
-
-def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
-              prompts: list[str]) -> list[str]:
-
-    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=MAX_TOKENS)
-    outputs = llm.generate(
-        prompts,
-        sampling_params,
-        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
-        if lora_id else None)
-    # Print the outputs.
-    generated_texts: list[str] = []
-    for output in outputs:
-        prompt = output.prompt
-        generated_text = output.outputs[0].text.strip()
-        generated_texts.append(generated_text)
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-    return generated_texts
-
-
-@pytest.mark.parametrize("tp_size", [4])
-def test_jamba_lora(jamba_lora_files, tp_size):
-    """Original test, the LoRA model has the common target modules, not all"""
-    if torch.cuda.device_count() < tp_size:
-        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
-
-    prompts = ["Write a story about a sheep and a goat."]
-
-    llm = vllm.LLM(
-        MODEL_PATH,
-        enable_lora=True,
-        max_num_seqs=16,
-        max_loras=4,
-        distributed_executor_backend="ray",
-        tensor_parallel_size=tp_size,
-    )
-
-    expected_jamba_output = [
-        """Once upon a time, in a lush green meadow, there lived a sheep named Clara and a goat named Billy. Clara was a gentle creature, always nibbling on the soft grass and humming"""  # noqa: E501
-    ]
-    assert do_sample(llm, jamba_lora_files, lora_id=1,
-                     prompts=prompts) == expected_jamba_output

View File

@@ -632,6 +632,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
     id_to_index = get_random_id_to_index(num_loras, max_loras)
     linear, lora_linear = create_random_linear_replicated_layer()
+    assert torch.equal(linear.weight, lora_linear.weight)
     lora_linear.set_mapping(punica_wrapper)
     lora_dict, _ = populate_loras(
         id_to_index,
@@ -757,6 +758,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
     id_to_index = get_random_id_to_index(num_loras, max_loras)
     linear, lora_linear = create_random_linear_parallel_layer()
+    assert torch.equal(linear.weight, lora_linear.weight)
     lora_linear.set_mapping(punica_wrapper)
     lora_dict, _ = populate_loras(
         id_to_index,
@@ -904,6 +906,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
     id_to_index = get_random_id_to_index(num_loras, max_loras)
     linear, lora_linear = create_column_parallel_packed_layer()
+    assert torch.equal(linear.weight, lora_linear.weight)
     lora_linear.set_mapping(punica_wrapper)
     lora_dict, sublora_dict = populate_loras(
         id_to_index,
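
The three added assertions all exercise the same thing: the LoRA wrapper now exposes the wrapped layer's weight through a property, so linear.weight and lora_linear.weight resolve to the same tensor. A minimal, self-contained sketch of that delegation pattern (the WrappedLinear class below is a hypothetical stand-in, not vLLM code):

# Minimal sketch (not vLLM code) of the property-delegation pattern
# that the new torch.equal assertions exercise.
import torch
import torch.nn as nn


class WrappedLinear(nn.Module):  # hypothetical stand-in for a *WithLoRA layer

    def __init__(self, base_layer: nn.Linear):
        super().__init__()
        self.base_layer = base_layer

    @property
    def weight(self) -> torch.Tensor:
        # Delegate to the wrapped layer so callers see the same tensor.
        return self.base_layer.weight


base = nn.Linear(16, 32, bias=False)
wrapped = WrappedLinear(base)
assert torch.equal(base.weight, wrapped.weight)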

View File

@@ -274,6 +274,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
     ) -> bool:
         return type(source_layer) is VocabParallelEmbedding

+    @property
+    def weight(self):
+        return self.base_layer.weight
+

 class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
@@ -409,6 +413,34 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
                                       self.output_slices)
         return output

+    @property
+    def weight(self) -> torch.Tensor:
+
+        # unquantizedLinear
+        if hasattr(self.base_layer, "weight"):
+            return self.base_layer.weight
+        # Compressed Tensor
+        elif hasattr(self.base_layer, "weight_packed"):
+            return self.base_layer.weight_packed
+        # GPTQ/AWQ
+        elif hasattr(self.base_layer, "qweight"):
+            return self.base_layer.qweight
+        # marlin
+        elif hasattr(self.base_layer, "B"):
+            return self.base_layer.B
+        # HQQ marlin
+        elif hasattr(self.base_layer, "W_q"):
+            return self.base_layer.W_q
+        else:
+            raise ValueError(f"Unsupported base layer: {self.base_layer}")
+
+    @property
+    def bias(self) -> Optional[torch.Tensor]:
+        if hasattr(self.base_layer, "bias"):
+            return self.base_layer.bias
+        else:
+            return None
+

 class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
@@ -902,11 +934,6 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):

         return output, output_bias

-    @property
-    def weight(self):
-        return (self.base_layer.weight if hasattr(self.base_layer, "weight")
-                else self.base_layer.qweight)
-
     @classmethod
     @_not_fully_sharded_can_replace
     def can_replace_layer(
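
Taken together, the new weight and bias properties on BaseLinearLayerWithLoRA replace the per-class property removed from RowParallelLinearWithLoRA: every linear LoRA wrapper now resolves the underlying parameter through one attribute-dispatch chain, since quantized backends store their packed weights under different names (weight_packed, qweight, B, W_q). A standalone sketch of that dispatch pattern follows; resolve_weight and resolve_bias are illustrative helpers, not vLLM APIs:

# Illustrative sketch (assumed helper names, not vLLM APIs) of the
# attribute-dispatch chain used by the new weight property: quantized
# linear implementations store packed weights under backend-specific
# attribute names, so the first matching attribute is returned.
from typing import Optional

import torch
import torch.nn as nn

_WEIGHT_ATTRS = (
    "weight",         # unquantized Linear
    "weight_packed",  # compressed-tensors
    "qweight",        # GPTQ / AWQ
    "B",              # Marlin
    "W_q",            # HQQ Marlin
)


def resolve_weight(base_layer: nn.Module) -> torch.Tensor:
    for name in _WEIGHT_ATTRS:
        if hasattr(base_layer, name):
            return getattr(base_layer, name)
    raise ValueError(f"Unsupported base layer: {base_layer}")


def resolve_bias(base_layer: nn.Module) -> Optional[torch.Tensor]:
    # Mirrors the new bias property: fall back to None when absent.
    return getattr(base_layer, "bias", None)


# A plain nn.Linear resolves to its ordinary weight and bias tensors.
layer = nn.Linear(8, 8)
assert resolve_weight(layer) is layer.weight
assert resolve_bias(layer) is layer.bias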