[TPU][V1] Add support for top-logprobs (#17072)

Signed-off-by: NickLucche <nlucches@redhat.com>

commit 5941e0b7ea
parent 9765940824
@@ -61,3 +61,51 @@ def test_sampler_different(model_name: str):
     # to have deterministic results over many tokens, tests the first ~20
     # tokens match.
     assert output[0].outputs[0].text[:20] == output[1].outputs[0].text[:20]
+
+
+@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
+# TODO TPU will appear busy if we fan-out test params here
+@pytest.mark.parametrize("n_prompts", [1])
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This test needs a TPU")
+def test_logprobs(model_name: str, n_prompts: int):
+    """
+    Request top logprobs with different sampling settings and check
+    that results contains the requested number, ordered ascendingly.
+    """
+
+    def check_num_logprobs(logprobs, expected_num: int):
+        for step in logprobs:
+            prev_logp = 1.0
+            # order by rank
+            sorted_step = dict(
+                sorted(step.items(), key=lambda item: item[1].rank))
+
+            # Can contain the sampled token
+            assert len(step) == expected_num or len(step) == expected_num + 1
+            # Check results are ordered by prob value
+            for rankno, (tid, logp) in enumerate(sorted_step.items()):
+                assert logp.logprob <= prev_logp
+                prev_logp = logp.logprob
+                assert logp.rank == rankno + 1
+
+    llm = LLM(model_name,
+              enforce_eager=False,
+              max_num_seqs=1,
+              max_model_len=128,
+              max_num_batched_tokens=128)
+    prompts = [
+        "Write a short story about a robot that dreams for the first time."
+    ] * n_prompts
+    greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64,\
+                                            logprobs=4)
+    regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\
+                                             logprobs=4)
+    topkp_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\
+                                           logprobs=4, top_k=12, top_p=0.5)
+
+    for sp in [greedy_sampling_params, regular_sampling_params, \
+               topkp_sampling_params]:
+        output = llm.generate(prompts, sp)
+        for o in output:
+            check_num_logprobs(o.outputs[0].logprobs, 4)
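For context on the structure `test_logprobs` iterates over: each entry of `o.outputs[0].logprobs` corresponds to one decoding step and maps a token id to a `Logprob` carrying `logprob`, `rank`, and `decoded_token`. A minimal sketch of inspecting it (not part of this diff; it reuses the model and parameters from the test):

    from vllm import LLM, SamplingParams

    llm = LLM("Qwen/Qwen2.5-1.5B-Instruct", max_model_len=128, max_num_seqs=1)
    out = llm.generate(["Hello"], SamplingParams(max_tokens=4, logprobs=4))[0]

    # One dict per generated token: token id -> Logprob(logprob, rank, ...).
    for step in out.outputs[0].logprobs:
        for token_id, lp in sorted(step.items(), key=lambda kv: kv[1].rank):
            print(token_id, lp.rank, round(lp.logprob, 3), lp.decoded_token)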
@@ -31,8 +31,10 @@ class TPUSupportedSamplingMetadata:
 
     all_greedy: bool = True
 
-    # unsupported, you need to return an extra tensor of static size BxV
-    max_num_logprobs = None
+    # Whether logprobs are to be gathered in this batch of request. To balance
+    # out compile time and runtime, a fixed `max_number_logprobs` value is used
+    # when gathering logprobs, regardless of the values specified in the batch.
+    logprobs: bool = False
 
     # TODO No penalties for now
     no_penalties: bool = True
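The comment on the new `logprobs` flag captures the design trade-off: gather a single, fixed top-K on device so only one graph shape has to be compiled, and honor each request's smaller `logprobs` value later on the host. A rough sketch of that idea (illustrative only, not vLLM code; `FIXED_K` stands in for `model_config.max_logprobs`):

    import torch

    FIXED_K = 20  # assumed stand-in for the model's max_logprobs

    def gather_fixed_topk(logprobs: torch.Tensor):
        # One compiled shape regardless of what each request asked for.
        return logprobs.topk(FIXED_K, dim=-1)

    def trim_per_request(values, token_ids, num_logprobs_per_req):
        # Dynamic, per-request slicing stays on CPU, outside the XLA graph.
        values, token_ids = values.cpu(), token_ids.cpu()
        return [(values[i, :n], token_ids[i, :n])
                for i, n in enumerate(num_logprobs_per_req)]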
@@ -84,10 +86,12 @@ class TPUSupportedSamplingMetadata:
         we want to pre-compile a graph with sampling parameters, even if
         they are not strictly needed for greedy decoding.
         """
+        needs_logprobs = input_batch.max_num_logprobs>0 if \
+            input_batch.max_num_logprobs else False
         # Early return to avoid unnecessary cpu to tpu copy
         if (input_batch.all_greedy is True
                 and generate_params_if_all_greedy is False):
-            return cls(all_greedy=True)
+            return cls(all_greedy=True, logprobs=needs_logprobs)
 
         num_reqs = input_batch.num_reqs
 
@@ -115,4 +119,5 @@ class TPUSupportedSamplingMetadata:
             top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
                 xla_device),
             min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
-                xla_device))
+                xla_device),
+            logprobs=needs_logprobs)
@@ -22,27 +22,18 @@ class Sampler(nn.Module):
         logits: torch.Tensor,
         sampling_metadata: TPUSupportedSamplingMetadata,
     ) -> SamplerOutput:
-        # NOTE(woosuk): Use the original logits (before any penalties or
-        # temperature scaling) for the top-k logprobs.
-        # This is different from the V0 sampler, which uses the logits that
-        # is used for sampling (after penalties and temperature scaling).
-
         # Use float32 for the logits.
         logits = logits.to(torch.float32)
         # Sample the next token.
         sampled = self.sample(logits, sampling_metadata)
 
-        # Use int32 to reduce the tensor size.
-        sampled = sampled.to(torch.int32)
-
-        # These are GPU tensors.
+        # These are TPU tensors.
         sampler_output = SamplerOutput(
             # The sampled tokens are expanded to 2D tensor with shape
             # [num_requests, 1], where each row represents one generated
             # token per request.
             sampled_token_ids=sampled.unsqueeze(-1),
-            logprobs_tensors=None,
-        )
+            logprobs_tensors=None)
         return sampler_output
 
     def apply_temperature(
@@ -50,7 +41,6 @@ class Sampler(nn.Module):
         logits: torch.Tensor,
         temp: torch.Tensor,
     ) -> torch.Tensor:
-        # Use in-place division to avoid creating a new tensor.
         return logits.div_(temp.unsqueeze(dim=1))
 
     def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
@@ -791,8 +791,18 @@ class TPUModelRunner:
                                             arange)
         selected_token_ids = self.sample_from_logits(logits,
                                                      tpu_sampling_metadata)
 
+        # NOTE (NickLucche) Use the original logits (before any penalties or
+        # temperature scaling) for the top-k logprobs. We can't enforce it due
+        # to recompilations outside torch.compiled code, so just make sure
+        # `sample_from_logits` does not modify the logits in-place.
+        logprobs = self.gather_logprobs(logits, selected_token_ids) \
+            if tpu_sampling_metadata.logprobs else None
+
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]
+        logprobs_lists = logprobs.tolists() \
+            if tpu_sampling_metadata.logprobs else None
+
         # Update the cache state concurrently. Code above will not block until
         # we use `selected_token_ids`. Add mark_step if post-processing changes
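The NOTE added above turns on the difference between in-place and out-of-place ops on the logits tensor: if sampling scaled `logits` in place, the very same tensor later handed to `gather_logprobs` would no longer hold the original values. A toy illustration (plain PyTorch, not vLLM code):

    import torch

    logits = torch.tensor([[2.0, 1.0, 0.5]])
    temp = torch.tensor([0.5])

    logits.div_(temp.unsqueeze(1))   # in-place: `logits` is now [[4.0, 2.0, 1.0]]
    # Any logprobs gathered from `logits` after this point see scaled values.

    logits = torch.tensor([[2.0, 1.0, 0.5]])
    scaled = logits.div(temp.unsqueeze(1))  # out-of-place: `logits` unchanged
    # Logprobs gathered from `logits` still reflect the pre-temperature scores.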
@@ -862,7 +872,7 @@ class TPUModelRunner:
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=valid_sampled_token_ids,
             spec_token_ids=None,
-            logprobs=None,
+            logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
         )
 
@@ -1121,6 +1131,22 @@ class TPUModelRunner:
         logger.info("Compilation finished in %.2f [secs].", end - start)
         self._update_num_xla_graphs("sample_from_logits")
 
+    def _precompile_gather_logprobs(self) -> None:
+        logger.info("Compiling gather_logprobs with different input shapes.")
+        start = time.perf_counter()
+        for num_reqs in self.num_reqs_paddings:
+            dummy_logits = torch.zeros((num_reqs, self.vocab_size),
+                                       device=self.device,
+                                       dtype=self._hidden_states_dtype)
+            dummy_tokens = torch.zeros((num_reqs, 1),
+                                       dtype=torch.int64).to(self.device)
+            self.gather_logprobs(dummy_logits, dummy_tokens)
+            logger.info(" -- num_seqs: %d", num_reqs)
+        xm.wait_device_ops()
+        end = time.perf_counter()
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("gather_logprobs")
+
     def capture_model(self) -> None:
         """
         Precompile all the subgraphs with possible input shapes.
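`_precompile_gather_logprobs` follows the same warm-up pattern as the other `_precompile_*` helpers: every padded request count is traced ahead of time, so the serving path only ever sees shapes that already have an XLA graph. A small sketch of the bucketing idea behind it (`pad_to_bucket` is a hypothetical helper; `paddings` stands in for `self.num_reqs_paddings`, assumed sorted):

    def pad_to_bucket(num_reqs: int, paddings: list) -> int:
        # Round the runtime request count up to a precompiled shape.
        for bucket in paddings:
            if num_reqs <= bucket:
                return bucket
        return paddings[-1]

    assert pad_to_bucket(3, [1, 2, 4, 8]) == 4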
@@ -1131,6 +1157,7 @@ class TPUModelRunner:
         self._precompile_compute_logits()
         self._precompile_structured_decoding()
         self._precompile_sample_from_logits()
+        self._precompile_gather_logprobs()
 
     def profile_run(
         self,
@@ -1254,6 +1281,10 @@ class TPUModelRunner:
     def sample_from_logits(
             self, logits: torch.Tensor,
             sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
+        """
+        Sample with xla-friendly function. This function is to be traced
+        separately from `forward` for lighter compilation overhead.
+        """
         if sampling_metadata.all_greedy:
             out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
         else:
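The docstring added here refers to compiling small helpers as their own graphs instead of folding them into the model's `forward`. A minimal sketch of that pattern (illustrative only; it assumes torch_xla is installed so the `openxla` backend is available):

    import torch

    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def standalone_argmax(x: torch.Tensor) -> torch.Tensor:
        # Traced and compiled on its own, once per distinct input shape,
        # keeping the main model graph smaller and cheaper to recompile.
        return torch.argmax(x, dim=-1, keepdim=True)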
@@ -1261,6 +1292,20 @@ class TPUModelRunner:
                                       sampling_metadata).sampled_token_ids
         return out_tokens
 
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def gather_logprobs(self, logits: torch.Tensor,
+                        sampled_tokens: torch.Tensor) -> LogprobsTensors:
+        """
+        Gather the top_logprobs with corresponding tokens. Use a fixed number
+        of logprobs as an alternative to having multiple pre-compiled graphs.
+        Select the number of logprobs actually demanded by each request on CPU.
+        """
+        logprobs = self.sampler.compute_logprobs(logits)
+        return self.sampler.gather_logprobs(
+            logprobs,
+            self.model_config.max_logprobs,
+            token_ids=sampled_tokens.squeeze(-1))
+
     @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def structured_decode(self, require_struct_decoding: torch.Tensor,
                           grammar_bitmask: torch.Tensor, logits: torch.Tensor,
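For readers unfamiliar with the two sampler helpers `gather_logprobs` delegates to, here is a rough sketch of what they compute (illustrative only, not the actual vLLM implementation): a log-softmax over the vocabulary, then the fixed top-K entries plus the logprob and rank of the sampled token.

    import torch

    def compute_logprobs_sketch(logits: torch.Tensor) -> torch.Tensor:
        return torch.log_softmax(logits.float(), dim=-1)

    def gather_logprobs_sketch(logprobs: torch.Tensor, k: int,
                               token_ids: torch.Tensor):
        topk_vals, topk_ids = logprobs.topk(k, dim=-1)
        sampled_vals = logprobs.gather(-1, token_ids.unsqueeze(-1))
        # Rank = how many vocab entries scored strictly higher, plus one.
        ranks = (logprobs > sampled_vals).sum(dim=-1, keepdim=True) + 1
        return topk_ids, topk_vals, sampled_vals, ranks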