From 1096717ae9e0b414ad625c1a12354dd1d949ffb1 Mon Sep 17 00:00:00 2001 From: Jee Li Date: Fri, 12 Apr 2024 12:02:44 +0800 Subject: [PATCH] [Core] Support LoRA on quantized models (#4012) --- tests/lora/conftest.py | 5 + tests/lora/test_quant_model.py | 179 +++++++++++++++++++++++++++++++++ vllm/config.py | 9 +- vllm/lora/layers.py | 67 +++++++----- 4 files changed, 234 insertions(+), 26 deletions(-) create mode 100644 tests/lora/test_quant_model.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 207c635e2dc86..1127cc33183c9 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -143,6 +143,11 @@ def baichuan_lora_files(): return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider") +@pytest.fixture(scope="session") +def tinyllama_lora_files(): + return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") + + @pytest.fixture def llama_2_7b_engine_extra_embeddings() -> nn.Module: cleanup() diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py new file mode 100644 index 0000000000000..3d86a4366aa57 --- /dev/null +++ b/tests/lora/test_quant_model.py @@ -0,0 +1,179 @@ +# Adapted from +# https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/tests/lora/test_llama.py +from dataclasses import dataclass +from typing import List + +import pytest + +import vllm +from vllm.lora.request import LoRARequest + +from .conftest import cleanup + + +@dataclass +class ModelWithQuantization: + model_path: str + quantization: str + + +MODELS: List[ModelWithQuantization] = [ + ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", + quantization="AWQ"), + ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", + quantization="GPTQ"), +] + + +def do_sample(llm, lora_path: str, lora_id: int, max_tokens=256): + raw_prompts = [ + "Give me an orange-ish brown color", + "Give me a neon pink color", + ] + + def format_prompt_tuples(prompt): + return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + + prompts = [format_prompt_tuples(p) for p in raw_prompts] + + sampling_params = vllm.SamplingParams(temperature=0, + max_tokens=max_tokens, + stop=["<|im_end|>"]) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. + generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tp_size", [1]) +def test_quant_model_lora(tinyllama_lora_files, model, tp_size): + # Cannot use as it will initialize torch.cuda too early... 
+ # if torch.cuda.device_count() < tp_size: + # pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + + llm = vllm.LLM(model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + max_model_len=400, + tensor_parallel_size=tp_size, + quantization=model.quantization, + trust_remote_code=True) + + if model.quantization is None: + expected_no_lora_output = [ + "Here are some examples of orange-brown colors", + "I'm sorry, I don't have" + ] + expected_lora_output = [ + "#ff8050", + "#ff8080", + ] + elif model.quantization == "AWQ": + expected_no_lora_output = [ + "I'm sorry, I don't understand", + "I'm sorry, I don't understand", + ] + expected_lora_output = [ + "#f07700: A v", + "#f00000: A v", + ] + elif model.quantization == "GPTQ": + expected_no_lora_output = [ + "I'm sorry, I don't have", + "I'm sorry, I don't have", + ] + expected_lora_output = [ + "#f08800: This is", + "#f07788 \n#", + ] + + def expect_match(output, expected_output): + # HACK: GPTQ lora outputs are just incredibly unstable. + # Assert that the outputs changed. + if (model.quantization == "GPTQ" + and expected_output is expected_lora_output): + assert output != expected_no_lora_output + for i, o in enumerate(output): + assert o.startswith( + '#'), f"Expected example {i} to start with # but got {o}" + return + assert output == expected_output + + max_tokens = 10 + + print("lora adapter created") + output = do_sample(llm, + tinyllama_lora_files, + lora_id=0, + max_tokens=max_tokens) + expect_match(output, expected_no_lora_output) + + print("lora 1") + output = do_sample(llm, + tinyllama_lora_files, + lora_id=1, + max_tokens=max_tokens) + expect_match(output, expected_lora_output) + + print("no lora") + output = do_sample(llm, + tinyllama_lora_files, + lora_id=0, + max_tokens=max_tokens) + expect_match(output, expected_no_lora_output) + + print("lora 2") + output = do_sample(llm, + tinyllama_lora_files, + lora_id=2, + max_tokens=max_tokens) + expect_match(output, expected_lora_output) + + print("removing lora") + + del llm + cleanup() + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.skip("Requires multiple GPUs") +def test_quant_model_tp_equality(tinyllama_lora_files, model): + # Cannot use as it will initialize torch.cuda too early... 
+ # if torch.cuda.device_count() < 2: + # pytest.skip(f"Not enough GPUs for tensor parallelism {2}") + + llm_tp1 = vllm.LLM(model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=1, + quantization=model.quantization, + trust_remote_code=True) + output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1) + + del llm_tp1 + cleanup() + + llm_tp2 = vllm.LLM(model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=2, + quantization=model.quantization) + output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1) + + del llm_tp2 + cleanup() + + assert output_tp1 == output_tp2 diff --git a/vllm/config.py b/vllm/config.py index 4102edbe01d35..da7eb2810ff05 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -822,9 +822,12 @@ class LoRAConfig: self.lora_dtype = model_config.dtype elif isinstance(self.lora_dtype, str): self.lora_dtype = getattr(torch, self.lora_dtype) - if model_config.quantization is not None: - raise ValueError( - "LoRA is not supported with quantized models yet.") + if model_config.quantization and model_config.quantization not in [ + "awq", "gptq" + ]: + # TODO support marlin and squeezellm + logger.warning(f"{model_config.quantization} quantization is not " + "tested with LoRA yet.") def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): if scheduler_config.max_num_batched_tokens > 65528: diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 5456b5613c47a..4b9653de73a88 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -29,6 +29,19 @@ if TYPE_CHECKING: pass +def _get_lora_device(base_layer: nn.Module) -> torch.device: + # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 + """Returns the device for where to place the LoRA tensors.""" + if hasattr(base_layer, "weight"): + return base_layer.weight.device + if hasattr(base_layer, "linear_weights") and isinstance( + base_layer.linear_weights, dict): + values = list(base_layer.linear_weights.values()) + if len(values) and isinstance(values[0], torch.Tensor): + return values[0].device + raise ValueError(f"Unsupported base layer: {base_layer}") + + def _apply_lora( x: torch.Tensor, lora_a_stacked: torch.Tensor, @@ -302,6 +315,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): super().__init__() self.base_layer = base_layer self.tp_size = get_tensor_model_parallel_world_size() + self.input_size = self.base_layer.input_size + self.output_size = self.base_layer.output_size_per_partition + self.device = _get_lora_device(self.base_layer) def create_lora_weights( self, @@ -312,17 +328,17 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) self.lora_b_stacked = torch.zeros( max_loras, 1, - self.base_layer.weight.shape[0], + self.output_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) self.indices: Optional[torch.Tensor] = None @@ -442,18 +458,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) for _ in range(n_slices)) self.lora_b_stacked = tuple( 
torch.zeros( max_loras, 1, - self.base_layer.weight.shape[0] // 2, + self.output_size // 2, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) for _ in range(n_slices)) self.indices: Optional[torch.Tensor] = None @@ -619,25 +635,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), torch.zeros( max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), torch.zeros( max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), ) self.lora_b_stacked = ( @@ -647,7 +663,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): self.q_proj_shard_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), torch.zeros( max_loras, @@ -655,7 +671,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): self.kv_proj_shard_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), torch.zeros( max_loras, @@ -663,7 +679,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): self.kv_proj_shard_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), ) @@ -766,6 +782,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: RowParallelLinear) -> None: super().__init__() self.base_layer = base_layer + self.input_size = self.base_layer.input_size_per_partition + self.output_size = self.base_layer.output_size + self.device = _get_lora_device(self.base_layer) def create_lora_weights( self, @@ -777,20 +796,20 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, ), dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) self.lora_b_stacked = torch.zeros( ( max_loras, 1, - self.base_layer.weight.shape[0], + self.output_size, lora_config.max_lora_rank, ), dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) self.indices: Optional[torch.Tensor] = None self.indices_len: Optional[List[int]] = None @@ -809,7 +828,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): self.reset_lora(index) if self.base_layer.tp_size > 1: tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.base_layer.weight.shape[1] + shard_size = self.input_size start_idx = tensor_model_parallel_rank * shard_size end_idx = (tensor_model_parallel_rank + 1) * shard_size lora_a = lora_a[start_idx:end_idx, :] @@ -884,7 +903,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): @property def weight(self): - return self.base_layer.weight + + return self.base_layer.weight if hasattr( + self.base_layer, "weight") else self.base_layer.qweight @classmethod def can_replace_layer(cls, source_layer: nn.Module,
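
For reference, a minimal usage sketch of what this patch enables (not part of the diff itself). It mirrors the new tests/lora/test_quant_model.py above: the base model is one of the quantized checkpoints from MODELS, and the adapter path is a placeholder standing in for the jashing/tinyllama-colorist-lora files that the test fixture downloads via snapshot_download.

import vllm
from vllm.lora.request import LoRARequest

# Load an AWQ-quantized base model with LoRA support enabled
# (previously LoRAConfig rejected any quantized model).
llm = vllm.LLM(model="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
               quantization="AWQ",
               enable_lora=True,
               max_loras=4,
               max_model_len=400,
               trust_remote_code=True)

prompt = ("<|im_start|>user\nGive me a neon pink color<|im_end|>\n"
          "<|im_start|>assistant\n")
sampling_params = vllm.SamplingParams(temperature=0,
                                      max_tokens=64,
                                      stop=["<|im_end|>"])

# Apply a LoRA adapter on top of the quantized weights; the path is a
# placeholder for a locally downloaded adapter such as the colorist LoRA
# used by the test above.
outputs = llm.generate(
    [prompt],
    sampling_params,
    lora_request=LoRARequest("colorist", 1, "/path/to/tinyllama-colorist-lora"))
print(outputs[0].outputs[0].text)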