fix LoRA-related examples (#29956)

Signed-off-by: Iceber Gu <caiwei95@hotmail.com>
This commit is contained in:
Iceber Gu 2025-12-04 11:48:30 +08:00 committed by GitHub
parent c493b9d092
commit 33a3d6c798
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 27 deletions

View File

@ -23,31 +23,23 @@ def create_test_prompts(
# this is an example of using quantization without LoRA # this is an example of using quantization without LoRA
( (
"My name is", "My name is",
SamplingParams( SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
),
None, None,
), ),
# the next three examples use quantization with LoRA # the next three examples use quantization with LoRA
( (
"my name is", "my name is",
SamplingParams( SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
),
LoRARequest("lora-test-1", 1, lora_path), LoRARequest("lora-test-1", 1, lora_path),
), ),
( (
"The capital of USA is", "The capital of USA is",
SamplingParams( SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
),
LoRARequest("lora-test-2", 1, lora_path), LoRARequest("lora-test-2", 1, lora_path),
), ),
( (
"The capital of France is", "The capital of France is",
SamplingParams( SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
),
LoRARequest("lora-test-3", 1, lora_path), LoRARequest("lora-test-3", 1, lora_path),
), ),
] ]

View File

@ -27,9 +27,7 @@ def create_test_prompts(
return [ return [
( (
"A robot may not injure a human being", "A robot may not injure a human being",
SamplingParams( SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
),
None, None,
), ),
( (
@ -41,22 +39,12 @@ def create_test_prompts(
), ),
( (
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams( SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128,
),
LoRARequest("sql-lora", 1, lora_path), LoRARequest("sql-lora", 1, lora_path),
), ),
( (
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams( SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128,
),
LoRARequest("sql-lora2", 2, lora_path), LoRARequest("sql-lora2", 2, lora_path),
), ),
] ]