[Model] Add Phi-2 LoRA support (#4886)

Isotr0py 2024-05-21 13:24:17 +08:00 committed by GitHub
parent d130b573a0
commit f12c3b5b3d
4 changed files with 99 additions and 6 deletions


@@ -118,7 +118,7 @@ Alongside each architecture, we include some popular models that use it.
   * - :code:`PhiForCausalLM`
     - Phi
     - :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc.
-    -
+    - ✅︎
   * - :code:`Phi3ForCausalLM`
     - Phi-3
     - :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, etc.
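The table change above marks the Phi architecture as LoRA-capable. A minimal offline-inference sketch of what this enables, mirroring the new test added below (the adapter repo is the one this commit's test uses; the "sql-lora" name, the integer id 1, and the prompt are illustrative choices, not part of this diff):

from huggingface_hub import snapshot_download

import vllm
from vllm.lora.request import LoRARequest

# Fetch the SQL-generation test adapter that the new test below also uses.
lora_path = snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")

# enable_lora=True activates the LoRA plumbing added to PhiForCausalLM in this commit.
llm = vllm.LLM("microsoft/phi-2", enable_lora=True, max_model_len=1024)

prompt = ("### Instruct: How many marine species are found in the Southern Ocean?\n\n"
          "### Context: CREATE TABLE marine_species (name VARCHAR(50), location VARCHAR(50));\n\n"
          "### Output:")

outputs = llm.generate(
    [prompt],
    vllm.SamplingParams(temperature=0, max_tokens=64, stop="### End"),
    # LoRARequest(adapter name, integer adapter id, local adapter path)
    lora_request=LoRARequest("sql-lora", 1, lora_path),
)
print(outputs[0].outputs[0].text.strip())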


@@ -165,6 +165,11 @@ def tinyllama_lora_files():
     return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")


+@pytest.fixture(scope="session")
+def phi2_lora_files():
+    return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")
+
+
 @pytest.fixture(scope="session")
 def long_context_lora_files_16k_1():
     return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1")

tests/lora/test_phi.py (new file, 67 lines)

@@ -0,0 +1,67 @@
import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "microsoft/phi-2"

PROMPT_TEMPLATE = "### Instruct: {sql_prompt}\n\n### Context: {context}\n\n### Output:"  # noqa: E501


def do_sample(llm, lora_path: str, lora_id: int) -> str:
    prompts = [
        PROMPT_TEMPLATE.format(
            sql_prompt=
            "Which catalog publisher has published the most catalogs?",
            context="CREATE TABLE catalogs (catalog_publisher VARCHAR);"),
        PROMPT_TEMPLATE.format(
            sql_prompt=
            "Which trip started from the station with the largest dock count? Give me the trip id.",  # noqa: E501
            context=
            "CREATE TABLE trip (id VARCHAR, start_station_id VARCHAR); CREATE TABLE station (id VARCHAR, dock_count VARCHAR);"  # noqa: E501
        ),
        PROMPT_TEMPLATE.format(
            sql_prompt=
            "How many marine species are found in the Southern Ocean?",  # noqa: E501
            context=
            "CREATE TABLE marine_species (name VARCHAR(50), common_name VARCHAR(50), location VARCHAR(50));"  # noqa: E501
        ),
    ]
    sampling_params = vllm.SamplingParams(temperature=0,
                                          max_tokens=64,
                                          stop="### End")
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None,
    )
    # Print the outputs.
    generated_texts = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


def test_phi2_lora(phi2_lora_files):
    # Enable enforce_eager=True to reduce VRAM usage in the LoRA-test CI;
    # otherwise the test fails with CUDA OOM.
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=2,
                   enforce_eager=True)

    expected_lora_output = [
        "SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;",  # noqa: E501
        "SELECT trip.id FROM trip JOIN station ON trip.start_station_id = station.id WHERE station.dock_count = (SELECT MAX(dock_count) FROM station);",  # noqa: E501
        "SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';",  # noqa: E501
    ]

    output1 = do_sample(llm, phi2_lora_files, lora_id=1)
    for i in range(len(expected_lora_output)):
        assert output1[i].startswith(expected_lora_output[i])
    output2 = do_sample(llm, phi2_lora_files, lora_id=2)
    for i in range(len(expected_lora_output)):
        assert output2[i].startswith(expected_lora_output[i])
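Both do_sample calls load the same adapter files under two different integer ids, which exercises the two adapter slots implied by max_loras=2 rather than two distinct adapters. The test can be run on its own with pytest tests/lora/test_phi.py, assuming a CUDA machine with the usual vLLM test dependencies installed.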


@@ -42,7 +42,7 @@ from torch import nn
 from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -229,11 +229,32 @@ class PhiModel(nn.Module):


 class PhiForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ]
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "dense",
+        "fc1",
+        "fc2",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []

-    def __init__(self,
-                 config: PretrainedConfig,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ):
+        del lora_config  # Unused.
         super().__init__()
         self.config = config
         self.quant_config = quant_config
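The new class attributes describe which Phi-2 submodules can carry LoRA weights: q_proj, k_proj and v_proj adapters are folded into the fused qkv_proj layer via packed_modules_mapping, while dense, fc1 and fc2 load one-to-one; the empty embedding_modules and embedding_padding_modules mean embedding LoRA is not wired up here. As a rough illustration only (not taken from this PR, with made-up rank and alpha values), a PEFT adapter trained against the Hugging Face Phi-2 module names would line up with these attributes like so:

from peft import LoraConfig

# Hypothetical training-side config: q_proj/k_proj/v_proj map onto vLLM's fused
# qkv_proj through packed_modules_mapping; dense, fc1 and fc2 load one-to-one.
adapter_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"],
    task_type="CAUSAL_LM",
)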