[Bugfix] Fix JambaForCausalLM LoRA (#14370)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

commit ddd1ef66ec (parent e5e03c2c1b)
--- a/tests/lora/conftest.py
+++ b/tests/lora/conftest.py
@@ -6,7 +6,6 @@ from typing import TypedDict
 from unittest.mock import MagicMock, patch
 
 import pytest
-import safetensors
 import torch
 import torch.nn as nn
 from huggingface_hub import snapshot_download
@@ -191,29 +190,6 @@ def mixtral_lora_files_all_target_modules():
     return snapshot_download(repo_id="dyang415/mixtral-lora-v0")
 
 
-@pytest.fixture(scope="session")
-def jamba_lora_files():
-    # some of the adapters have unnecessary weights for serving,
-    # hence we remove them
-    def remove_unnecessary_weights(path):
-        lora_path = f"{adapter_path}/adapter_model.safetensors"
-        tensors = safetensors.torch.load_file(lora_path)
-        nonlora_keys = []
-        for k in list(tensors.keys()):
-            if "lora" not in k:
-                nonlora_keys.append(k)
-        for k in nonlora_keys:
-            del tensors[k]
-        safetensors.torch.save_file(tensors, lora_path)
-
-    adapter_path = snapshot_download(
-        repo_id=
-        "hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora")
-
-    remove_unnecessary_weights(adapter_path)
-    return adapter_path
-
-
 @pytest.fixture(scope="session")
 def gemma_lora_files():
     return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
--- a/tests/lora/test_jamba.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-import pytest
-import torch
-
-import vllm
-from vllm.lora.request import LoRARequest
-
-MODEL_PATH = "ai21labs/AI21-Jamba-1.5-Mini"
-
-MAX_TOKENS = 40
-
-
-def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
-              prompts: list[str]) -> list[str]:
-
-    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=MAX_TOKENS)
-    outputs = llm.generate(
-        prompts,
-        sampling_params,
-        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
-        if lora_id else None)
-    # Print the outputs.
-    generated_texts: list[str] = []
-    for output in outputs:
-        prompt = output.prompt
-        generated_text = output.outputs[0].text.strip()
-        generated_texts.append(generated_text)
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-    return generated_texts
-
-
-@pytest.mark.parametrize("tp_size", [4])
-def test_jamba_lora(jamba_lora_files, tp_size):
-    """Original test, the LoRA model has the common target modules, not all"""
-    if torch.cuda.device_count() < tp_size:
-        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
-
-    prompts = ["Write a story about a sheep and a goat."]
-
-    llm = vllm.LLM(
-        MODEL_PATH,
-        enable_lora=True,
-        max_num_seqs=16,
-        max_loras=4,
-        distributed_executor_backend="ray",
-        tensor_parallel_size=tp_size,
-    )
-
-    expected_jamba_output = [
-        """Once upon a time, in a lush green meadow, there lived a sheep named Clara and a goat named Billy. Clara was a gentle creature, always nibbling on the soft grass and humming"""  # noqa: E501
-    ]
-    assert do_sample(llm, jamba_lora_files, lora_id=1,
-                     prompts=prompts) == expected_jamba_output
--- a/tests/lora/test_layers.py
+++ b/tests/lora/test_layers.py
@@ -632,6 +632,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
 
     id_to_index = get_random_id_to_index(num_loras, max_loras)
     linear, lora_linear = create_random_linear_replicated_layer()
+    assert torch.equal(linear.weight, lora_linear.weight)
     lora_linear.set_mapping(punica_wrapper)
     lora_dict, _ = populate_loras(
         id_to_index,
@@ -757,6 +758,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
 
     id_to_index = get_random_id_to_index(num_loras, max_loras)
     linear, lora_linear = create_random_linear_parallel_layer()
+    assert torch.equal(linear.weight, lora_linear.weight)
     lora_linear.set_mapping(punica_wrapper)
     lora_dict, _ = populate_loras(
         id_to_index,
@@ -904,6 +906,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
     id_to_index = get_random_id_to_index(num_loras, max_loras)
 
     linear, lora_linear = create_column_parallel_packed_layer()
+    assert torch.equal(linear.weight, lora_linear.weight)
     lora_linear.set_mapping(punica_wrapper)
     lora_dict, sublora_dict = populate_loras(
         id_to_index,
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -274,6 +274,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
     ) -> bool:
         return type(source_layer) is VocabParallelEmbedding
 
+    @property
+    def weight(self):
+        return self.base_layer.weight
+
 
 class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
 
@@ -409,6 +413,34 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
                             self.output_slices)
         return output
 
+    @property
+    def weight(self) -> torch.Tensor:
+
+        # unquantizedLinear
+        if hasattr(self.base_layer, "weight"):
+            return self.base_layer.weight
+        # Compressed Tensor
+        elif hasattr(self.base_layer, "weight_packed"):
+            return self.base_layer.weight_packed
+        # GPTQ/AWQ
+        elif hasattr(self.base_layer, "qweight"):
+            return self.base_layer.qweight
+        # marlin
+        elif hasattr(self.base_layer, "B"):
+            return self.base_layer.B
+        # HQQ marlin
+        elif hasattr(self.base_layer, "W_q"):
+            return self.base_layer.W_q
+        else:
+            raise ValueError(f"Unsupported base layer: {self.base_layer}")
+
+    @property
+    def bias(self) -> Optional[torch.Tensor]:
+        if hasattr(self.base_layer, "bias"):
+            return self.base_layer.bias
+        else:
+            return None
+
 
 class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
 
@@ -902,11 +934,6 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
 
         return output, output_bias
 
-    @property
-    def weight(self):
-        return (self.base_layer.weight if hasattr(self.base_layer, "weight")
-                else self.base_layer.qweight)
-
    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
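
Note: the sketch below is not part of the commit. It is a minimal, self-contained illustration of the attribute dispatch that the new BaseLinearLayerWithLoRA.weight property performs, so that quantized and unquantized base layers expose a uniform .weight surface for assertions like torch.equal(linear.weight, lora_linear.weight) in the updated tests. LoRAWrapperSketch, FakeUnquantizedLinear, and FakeGPTQLinear are hypothetical stand-ins, not vLLM classes.

    # Sketch only: mirrors the hasattr-based dispatch added in this commit.
    from typing import Optional

    import torch


    class LoRAWrapperSketch:

        def __init__(self, base_layer) -> None:
            self.base_layer = base_layer

        @property
        def weight(self) -> torch.Tensor:
            # Attribute names used by the different backends:
            # unquantized -> weight, compressed-tensors -> weight_packed,
            # GPTQ/AWQ -> qweight, Marlin -> B, HQQ Marlin -> W_q.
            for attr in ("weight", "weight_packed", "qweight", "B", "W_q"):
                if hasattr(self.base_layer, attr):
                    return getattr(self.base_layer, attr)
            raise ValueError(f"Unsupported base layer: {self.base_layer}")

        @property
        def bias(self) -> Optional[torch.Tensor]:
            return getattr(self.base_layer, "bias", None)


    class FakeUnquantizedLinear:
        # Hypothetical stand-in for an unquantized linear base layer.
        def __init__(self) -> None:
            self.weight = torch.zeros(4, 4)


    class FakeGPTQLinear:
        # Hypothetical stand-in for a GPTQ-style base layer (qweight attribute).
        def __init__(self) -> None:
            self.qweight = torch.zeros(4, 4, dtype=torch.int32)


    if __name__ == "__main__":
        # Both wrappers expose the same .weight regardless of the backing attribute.
        print(LoRAWrapperSketch(FakeUnquantizedLinear()).weight.shape)
        print(LoRAWrapperSketch(FakeGPTQLinear()).weight.dtype)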