[TPU][V1] Add support for top-logprobs (#17072)

Signed-off-by: NickLucche <nlucches@redhat.com>
Nicolò Lucchesi 2025-05-05 23:20:15 +02:00 committed by GitHub
parent 9765940824
commit 5941e0b7ea
4 changed files with 105 additions and 17 deletions


@@ -61,3 +61,51 @@ def test_sampler_different(model_name: str):
# to have deterministic results over many tokens, tests the first ~20
# tokens match.
assert output[0].outputs[0].text[:20] == output[1].outputs[0].text[:20]
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
# TODO TPU will appear busy if we fan-out test params here
@pytest.mark.parametrize("n_prompts", [1])
@pytest.mark.skipif(not current_platform.is_tpu(),
reason="This test needs a TPU")
def test_logprobs(model_name: str, n_prompts: int):
"""
Request top logprobs with different sampling settings and check
that the results contain the requested number of entries, ordered by
increasing rank (i.e. decreasing log-probability).
"""
def check_num_logprobs(logprobs, expected_num: int):
for step in logprobs:
prev_logp = 1.0
# order by rank
sorted_step = dict(
sorted(step.items(), key=lambda item: item[1].rank))
# Can contain the sampled token
assert len(step) == expected_num or len(step) == expected_num + 1
# Check results are ordered by prob value
for rankno, (tid, logp) in enumerate(sorted_step.items()):
assert logp.logprob <= prev_logp
prev_logp = logp.logprob
assert logp.rank == rankno + 1
llm = LLM(model_name,
enforce_eager=False,
max_num_seqs=1,
max_model_len=128,
max_num_batched_tokens=128)
prompts = [
"Write a short story about a robot that dreams for the first time."
] * n_prompts
greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64,\
logprobs=4)
regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\
logprobs=4)
topkp_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\
logprobs=4, top_k=12, top_p=0.5)
for sp in [greedy_sampling_params, regular_sampling_params, \
topkp_sampling_params]:
output = llm.generate(prompts, sp)
for o in output:
check_num_logprobs(o.outputs[0].logprobs, 4)
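
For reference, a minimal sketch of consuming the logprobs requested above outside the test, using only the fields the assertions rely on (`rank` and `logprob`); the prompt and token count are illustrative:

from vllm import LLM, SamplingParams

llm = LLM("Qwen/Qwen2.5-1.5B-Instruct", max_num_seqs=1, max_model_len=128)
params = SamplingParams(temperature=0.0, max_tokens=8, logprobs=4)
out = llm.generate(["The robot dreamed of"], params)
for step in out[0].outputs[0].logprobs:
    # Each step maps token_id -> Logprob; it holds the top-4 candidates,
    # plus possibly the sampled token itself.
    for token_id, lp in sorted(step.items(), key=lambda kv: kv[1].rank):
        print(token_id, lp.rank, lp.logprob)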


@@ -31,8 +31,10 @@ class TPUSupportedSamplingMetadata:
all_greedy: bool = True
# unsupported, you need to return an extra tensor of static size BxV
max_num_logprobs = None
# Whether logprobs are to be gathered in this batch of requests. To balance
# compile time and runtime, a fixed number of logprobs (the model config's
# `max_logprobs`) is gathered, regardless of the values requested in the batch.
logprobs: bool = False
# TODO No penalties for now
no_penalties: bool = True
@@ -84,10 +86,12 @@ class TPUSupportedSamplingMetadata:
we want to pre-compile a graph with sampling parameters, even if
they are not strictly needed for greedy decoding.
"""
needs_logprobs = input_batch.max_num_logprobs>0 if \
input_batch.max_num_logprobs else False
# Early return to avoid unnecessary cpu to tpu copy
if (input_batch.all_greedy is True
and generate_params_if_all_greedy is False):
return cls(all_greedy=True)
return cls(all_greedy=True, logprobs=needs_logprobs)
num_reqs = input_batch.num_reqs
@@ -115,4 +119,5 @@ class TPUSupportedSamplingMetadata:
top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
xla_device),
min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
xla_device))
xla_device),
logprobs=needs_logprobs)
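
To illustrate the trade-off described in the comment above, here is a plain-torch sketch (not the vLLM implementation): gathering a fixed top-k keeps every tensor shape static for XLA, while the per-request logprobs counts are applied later on the host. The helper name and the value k=20 are made up for the example.

import torch

def gather_fixed_topk(logprobs: torch.Tensor, k: int):
    # Output shapes are [num_reqs, k] no matter what each request asked for,
    # so a single compiled graph serves every batch.
    return torch.topk(logprobs, k, dim=-1)

batch_logprobs = torch.log_softmax(torch.randn(4, 32000), dim=-1)
vals, ids = gather_fixed_topk(batch_logprobs, k=20)
# Trimming to what each request actually asked for happens on the CPU:
requested = [4, 1, 20, 8]
trimmed = [(vals[i, :n], ids[i, :n]) for i, n in enumerate(requested)]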


@@ -22,27 +22,18 @@ class Sampler(nn.Module):
logits: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata,
) -> SamplerOutput:
# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
# This is different from the V0 sampler, which uses the logits that
# is used for sampling (after penalties and temperature scaling).
# Use float32 for the logits.
logits = logits.to(torch.float32)
# Sample the next token.
sampled = self.sample(logits, sampling_metadata)
# Use int32 to reduce the tensor size.
sampled = sampled.to(torch.int32)
# These are GPU tensors.
# These are TPU tensors.
sampler_output = SamplerOutput(
# The sampled tokens are expanded to 2D tensor with shape
# [num_requests, 1], where each row represents one generated
# token per request.
sampled_token_ids=sampled.unsqueeze(-1),
logprobs_tensors=None,
)
logprobs_tensors=None)
return sampler_output
def apply_temperature(
@@ -50,7 +41,6 @@ class Sampler(nn.Module):
logits: torch.Tensor,
temp: torch.Tensor,
) -> torch.Tensor:
# Use in-place division to avoid creating a new tensor.
return logits.div_(temp.unsqueeze(dim=1))
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:


@@ -791,8 +791,18 @@ class TPUModelRunner:
arange)
selected_token_ids = self.sample_from_logits(logits,
tpu_sampling_metadata)
# NOTE (NickLucche) Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs. We can't enforce this without
# causing recompilations outside the torch.compiled code, so just make sure
# `sample_from_logits` does not modify the logits in-place.
logprobs = self.gather_logprobs(logits, selected_token_ids) \
if tpu_sampling_metadata.logprobs else None
# Remove padding on the CPU to keep the dynamic slicing op outside of the
# XLA graph.
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
logprobs_lists = logprobs.tolists() \
if tpu_sampling_metadata.logprobs else None
# Update the cache state concurrently. Code above will not block until
# we use `selected_token_ids`. Add mark_step if post-processing changes
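
The constraint in the NOTE above can be shown in isolation: if anything between computing the logits and gathering the logprobs scaled them in place (for example an in-place temperature division), the top-k logprobs would be taken from the scaled values instead of the original ones. A small plain-torch sketch, not vLLM code:

import torch

logits = torch.randn(2, 8)
original = logits.clone()
temp = torch.tensor([0.4, 0.7])

scaled = logits.div(temp.unsqueeze(1))    # out-of-place: `logits` is untouched
assert torch.equal(logits, original)

logits.div_(temp.unsqueeze(1))            # in-place: `logits` itself is scaled,
assert not torch.equal(logits, original)  # so later logprobs would be wrong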
@@ -862,7 +872,7 @@ class TPUModelRunner:
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=None,
logprobs=None,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
)
@@ -1121,6 +1131,22 @@ class TPUModelRunner:
logger.info("Compilation finished in %.2f [secs].", end - start)
self._update_num_xla_graphs("sample_from_logits")
def _precompile_gather_logprobs(self) -> None:
logger.info("Compiling gather_logprobs with different input shapes.")
start = time.perf_counter()
for num_reqs in self.num_reqs_paddings:
dummy_logits = torch.zeros((num_reqs, self.vocab_size),
device=self.device,
dtype=self._hidden_states_dtype)
dummy_tokens = torch.zeros((num_reqs, 1),
dtype=torch.int64).to(self.device)
self.gather_logprobs(dummy_logits, dummy_tokens)
logger.info(" -- num_seqs: %d", num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in %.2f [secs].", end - start)
self._update_num_xla_graphs("gather_logprobs")
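
As with the other `_precompile_*` helpers, the dummy tensors must match the shapes and dtypes used at runtime, since with `dynamic=False` every new input shape triggers a fresh trace. A generic illustration of that behaviour with a counting backend (plain `torch.compile` on CPU, not the openxla path or the vLLM helper):

import torch

compile_count = 0

def counting_backend(gm, example_inputs):
    # Called once per (re)compilation; just run the captured graph eagerly.
    global compile_count
    compile_count += 1
    return gm.forward

@torch.compile(backend=counting_backend, dynamic=False)
def gather(logits, tokens):
    return torch.log_softmax(logits, dim=-1).gather(-1, tokens)

for n in (8, 16, 32):  # one trace per distinct (padded) batch size
    gather(torch.randn(n, 128), torch.zeros(n, 1, dtype=torch.int64))
print(compile_count)   # 3 -- hence warming up every padding ahead of time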
def capture_model(self) -> None:
"""
Precompile all the subgraphs with possible input shapes.
@@ -1131,6 +1157,7 @@ class TPUModelRunner:
self._precompile_compute_logits()
self._precompile_structured_decoding()
self._precompile_sample_from_logits()
self._precompile_gather_logprobs()
def profile_run(
self,
@@ -1254,6 +1281,10 @@ class TPUModelRunner:
def sample_from_logits(
self, logits: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
"""
Sample with an XLA-friendly function. This function is traced
separately from `forward` to keep compilation overhead low.
"""
if sampling_metadata.all_greedy:
out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
else:
@@ -1261,6 +1292,20 @@ class TPUModelRunner:
sampling_metadata).sampled_token_ids
return out_tokens
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def gather_logprobs(self, logits: torch.Tensor,
sampled_tokens: torch.Tensor) -> LogprobsTensors:
"""
Gather the top logprobs and their corresponding tokens. A fixed number
of logprobs is gathered as an alternative to pre-compiling one graph per
requested value; the number each request actually asked for is then
selected on the CPU.
"""
logprobs = self.sampler.compute_logprobs(logits)
return self.sampler.gather_logprobs(
logprobs,
self.model_config.max_logprobs,
token_ids=sampled_tokens.squeeze(-1))
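
For intuition, the two sampler calls above amount to a log-softmax over the vocabulary, a fixed-size top-k, and the logprob/rank of the sampled token itself; every shape depends only on the batch size and the fixed k, which is what keeps the graph XLA-friendly. A rough plain-torch sketch, illustrative only and not the actual `compute_logprobs`/`gather_logprobs` implementation:

import torch

def topk_logprobs_sketch(logits: torch.Tensor, sampled: torch.Tensor, k: int):
    # logits: [B, V]; sampled: [B] token ids; k: fixed max number of logprobs.
    logprobs = torch.log_softmax(logits.to(torch.float32), dim=-1)
    topk_vals, topk_ids = torch.topk(logprobs, k, dim=-1)      # [B, k] each
    sampled_lp = logprobs.gather(-1, sampled.unsqueeze(-1))    # [B, 1]
    # 1-based rank: number of tokens with a strictly larger logprob, plus one.
    sampled_rank = (logprobs > sampled_lp).sum(dim=-1) + 1     # [B]
    return topk_vals, topk_ids, sampled_lp.squeeze(-1), sampled_rank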
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def structured_decode(self, require_struct_decoding: torch.Tensor,
grammar_bitmask: torch.Tensor, logits: torch.Tensor,