mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 17:45:19 +08:00
[Model] Add Phi-2 LoRA support (#4886)
This commit is contained in:
parent
d130b573a0
commit
f12c3b5b3d
@ -118,7 +118,7 @@ Alongside each architecture, we include some popular models that use it.
|
||||
* - :code:`PhiForCausalLM`
|
||||
- Phi
|
||||
- :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc.
|
||||
-
|
||||
- ✅︎
|
||||
* - :code:`Phi3ForCausalLM`
|
||||
- Phi-3
|
||||
- :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, etc.
|
||||
|
||||
@ -165,6 +165,11 @@ def tinyllama_lora_files():
|
||||
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def phi2_lora_files():
    """Session-scoped fixture that fetches the Phi-2 SQL LoRA test adapter.

    Returns whatever ``snapshot_download`` returns for the repository
    (presumably the local snapshot path -- confirm against its API).
    """
    repo_id = "isotr0py/phi-2-test-sql-lora"
    return snapshot_download(repo_id=repo_id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def long_context_lora_files_16k_1():
    """Session-scoped fixture that fetches the first 16k long-context adapter.

    Returns whatever ``snapshot_download`` returns for the repository
    (presumably the local snapshot path -- confirm against its API).
    """
    repo_id = "SangBinCho/long_context_16k_testing_1"
    return snapshot_download(repo_id=repo_id)
|
||||
|
||||
67
tests/lora/test_phi.py
Normal file
67
tests/lora/test_phi.py
Normal file
@ -0,0 +1,67 @@
|
||||
from typing import List

import vllm
from vllm.lora.request import LoRARequest
|
||||
|
||||
MODEL_PATH = "microsoft/phi-2"
|
||||
|
||||
PROMPT_TEMPLATE = "### Instruct: {sql_prompt}\n\n### Context: {context}\n\n### Output:" # noqa: E501
|
||||
|
||||
|
||||
def do_sample(llm, lora_path: str, lora_id: int) -> List[str]:
    """Generate SQL completions for three fixed prompts, optionally with LoRA.

    Args:
        llm: Object exposing ``generate(prompts, sampling_params,
            lora_request=...)``; the test in this file passes a ``vllm.LLM``.
        lora_path: Location of the LoRA adapter files.
        lora_id: Adapter id. It is used both as the request name (as a
            string) and as the integer id. A falsy value (``0``) runs the
            prompts without any LoRA request.

    Returns:
        The whitespace-stripped text of the first completion of every
        prompt, in prompt order.
    """
    prompts = [
        PROMPT_TEMPLATE.format(
            sql_prompt=
            "Which catalog publisher has published the most catalogs?",
            context="CREATE TABLE catalogs (catalog_publisher VARCHAR);"),
        PROMPT_TEMPLATE.format(
            sql_prompt=
            "Which trip started from the station with the largest dock count? Give me the trip id.",  # noqa: E501
            context=
            "CREATE TABLE trip (id VARCHAR, start_station_id VARCHAR); CREATE TABLE station (id VARCHAR, dock_count VARCHAR);"  # noqa: E501
        ),
        PROMPT_TEMPLATE.format(
            sql_prompt=
            "How many marine species are found in the Southern Ocean?",  # noqa: E501
            context=
            "CREATE TABLE marine_species (name VARCHAR(50), common_name VARCHAR(50), location VARCHAR(50));"  # noqa: E501
        ),
    ]
    # temperature=0 keeps generation deterministic enough for the caller to
    # compare against fixed expected strings.
    sampling_params = vllm.SamplingParams(temperature=0,
                                          max_tokens=64,
                                          stop="### End")
    # A falsy lora_id means "no adapter": send no LoRA request at all.
    lora_request = (LoRARequest(str(lora_id), lora_id, lora_path)
                    if lora_id else None)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=lora_request,
    )

    # Collect the outputs, echoing each one to ease debugging of failures.
    generated_texts = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts
|
||||
|
||||
|
||||
def test_phi2_lora(phi2_lora_files):
    """Check that Phi-2 with the SQL LoRA adapter emits the expected queries.

    The same adapter files are submitted under two different LoRA ids; each
    run must produce completions that start with the expected SQL.
    """
    # enforce_eager=True reduces VRAM usage in the lora-test CI; without it
    # the lora-test fails with a CUDA OOM.
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=2,
                   enforce_eager=True)

    expected_lora_output = [
        "SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;",  # noqa: E501
        "SELECT trip.id FROM trip JOIN station ON trip.start_station_id = station.id WHERE station.dock_count = (SELECT MAX(dock_count) FROM station);",  # noqa: E501
        "SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';",  # noqa: E501
    ]

    for lora_id in (1, 2):
        outputs = do_sample(llm, phi2_lora_files, lora_id=lora_id)
        # Guard the zip below: a short result must fail, not be truncated.
        assert len(outputs) == len(expected_lora_output), (
            f"lora_id={lora_id}: got {len(outputs)} outputs, "
            f"expected {len(expected_lora_output)}")
        for actual, expected in zip(outputs, expected_lora_output):
            assert actual.startswith(expected), (
                f"lora_id={lora_id}: {actual!r} does not start with "
                f"{expected!r}")
|
||||
@ -42,7 +42,7 @@ from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@ -229,11 +229,32 @@ class PhiModel(nn.Module):
|
||||
|
||||
|
||||
class PhiForCausalLM(nn.Module):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
]
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
# LoRA specific attributes
|
||||
supported_lora_modules = [
|
||||
"qkv_proj",
|
||||
"dense",
|
||||
"fc1",
|
||||
"fc2",
|
||||
]
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
):
|
||||
del lora_config # Unused.
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user