commit fa3bba2a53
parent 986537f1c3

[TPU][V1] Enable Top-P (#16843)

Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
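This commit turns on per-request top-p sampling for the V1 TPU backend, alongside the existing top-k and min-p support. As a quick illustration of what it enables, a minimal offline-inference sketch mirroring the batched case exercised by the test below (the model name and prompt are placeholders, not taken from this commit):

    # Minimal usage sketch: per-request top_k AND top_p now both take
    # effect on TPU. Model name and prompt are placeholders.
    import random

    from vllm import LLM, SamplingParams

    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")
    prompts = ["The next planet to be discovered will be"] * 4
    sampling_params = [
        SamplingParams(
            temperature=0.8,
            max_tokens=64,
            top_k=random.randint(4, 12),  # vary k per request
            top_p=random.random(),        # vary p per request
        ) for _ in range(4)
    ]
    outputs = llm.generate(prompts, sampling_params)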
@@ -42,7 +42,7 @@ def test_sampler_different(model_name: str):
     sampling_params = SamplingParams(temperature=0.3, seed=42)
     output2 = llm.generate(prompts, sampling_params)
 
-    # Batch-case with TopK
+    # Batch-case with TopK/P
     for B in [4, 16]:
         p = prompts * B
         sampling_params = [
@@ -51,9 +51,10 @@ def test_sampler_different(model_name: str):
             min_p=0.8,
             max_tokens=64,
             # Vary number of ks
-            top_k=random.randint(4, 12)) for _ in range(B)
+            top_k=random.randint(4, 12),
+            top_p=random.random()) for _ in range(B)
         ]
-        # Make sure first two reqs have the same K
+        # Make sure first two reqs have the same K/P
         sampling_params[0] = sampling_params[1]
         output = llm.generate(p, sampling_params)
         assert output[0].outputs[0].text == output[1].outputs[0].text
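For reference, top-p (nucleus) sampling keeps the smallest set of highest-probability tokens whose cumulative mass reaches p and samples from that set only. A generic PyTorch sketch of the filtering step, purely illustrative and not the XLA kernel this commit enables:

    # Generic top-p filtering sketch (illustrative, not the TPU kernel).
    import torch

    def apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
        """Mask logits outside the nucleus. logits: [B, V], p: [B]."""
        logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
        probs_sum = logits_sort.softmax(dim=-1).cumsum(dim=-1)
        # A token falls outside the nucleus when the tokens at or below
        # it (in ascending order) still leave at least mass p above.
        mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        logits_sort.masked_fill_(mask, float("-inf"))
        # Undo the sort to restore vocabulary order.
        return logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)

Note that the highest-probability token is never masked (its cumulative sum is 1 > 1 - p for any p > 0), so at least one candidate always survives.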
@@ -11,7 +11,7 @@ DEFAULT_SAMPLING_PARAMS = dict(
     min_p=0.0,
     # strictly disabled for now
     top_k=0,
-    # top_p=0.0,
+    top_p=1.0,
     # frequency_penalties=0.0,
     # presence_penalties=0.0,
     # repetition_penalties=0.0,
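Note the chosen default: top_p=1.0 keeps the entire vocabulary, so padded or unset request slots sample exactly as before. A quick check, reusing the apply_top_p sketch above:

    # top_p == 1.0 is a no-op: softmax probabilities are strictly
    # positive, so probs_sum <= 0 never holds and nothing is masked.
    logits = torch.randn(2, 8)
    p_one = torch.ones(2)
    assert torch.equal(apply_top_p(logits, p_one), logits)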
@@ -26,11 +26,9 @@ class TPUSupportedSamplingMetadata:
     temperature: torch.Tensor = None
 
     min_p: torch.Tensor = None
-    # Still too slow on forward_native!
     top_k: torch.Tensor = None
     top_p: torch.Tensor = None
 
-    # Greedy sampling flag for compiling single xla graph.
     all_greedy: bool = True
 
     # unsupported, you need to return an extra tensor of static size BxV
@@ -103,9 +101,8 @@ class TPUSupportedSamplingMetadata:
                    DEFAULT_SAMPLING_PARAMS["min_p"])
         fill_slice(input_batch.top_k_cpu_tensor,
                    DEFAULT_SAMPLING_PARAMS["top_k"])
-        # TODO Temporarily disabled until sampling options are enabled
-        # fill_slice(input_batch.top_p_cpu_tensor,
-        # DEFAULT_SAMPLING_PARAMS["top_p"])
+        fill_slice(input_batch.top_p_cpu_tensor,
+                   DEFAULT_SAMPLING_PARAMS["top_p"])
 
         # Slice persistent device tensors to a fixed pre-compiled padded shape.
         return cls(
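The restored fill_slice call pads the inactive tail of the persistent top_p CPU tensor with the neutral default before slicing. A plausible reconstruction of that helper follows; its exact signature is an assumption on my part, and the slice bounds are presumably captured from the enclosing scope in the real code (the call sites pass only the tensor and the value), whereas here they are explicit parameters to keep the sketch self-contained:

    # Assumed shape of the fill_slice helper (not shown in this diff):
    # overwrite the padding slots past the live requests in place, so
    # compiled graphs never see stale per-request values.
    import torch

    def fill_slice(cpu_tensor: torch.Tensor, fill_val: float,
                   start: int, end: int) -> None:
        cpu_tensor[start:end] = fill_val

    top_p_cpu = torch.rand(8)          # persistent tensor, max capacity
    fill_slice(top_p_cpu, 1.0, 3, 8)   # 3 live requests, padded to 8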
@@ -113,7 +110,8 @@ class TPUSupportedSamplingMetadata:
                 to(xla_device),
             all_greedy=input_batch.all_greedy,
             # TODO enable more and avoid returning None values
-            top_p=None,  # input_batch.top_p[:padded_num_reqs],
+            top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(
+                xla_device),
             top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
                 xla_device),
             min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
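top_p now follows the same pattern as top_k and min_p: slice the persistent CPU tensor to the padded request count, then move it to the XLA device. Slicing to padded_num_reqs rather than the live request count keeps tensor shapes drawn from a small fixed set, so XLA reuses compiled graphs instead of recompiling per batch size. A toy illustration, where the power-of-two bucketing is a hypothetical stand-in for the backend's actual padding rule:

    # Toy illustration of padded slicing; the bucketing rule here is an
    # assumption, not the TPU backend's actual one.
    import torch

    MAX_NUM_REQS = 256
    top_p_cpu = torch.ones(MAX_NUM_REQS)   # persistent CPU-side tensor

    def padded(n: int) -> int:
        return 1 << (n - 1).bit_length()   # round up to a power of two

    for num_reqs in (3, 4, 5, 9):
        sliced = top_p_cpu[:padded(num_reqs)]
        print(num_reqs, "->", tuple(sliced.shape))  # shapes: 4, 4, 8, 16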