mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:24:54 +08:00
[TPU][V1] Add support for top-logprobs (#17072)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
9765940824
commit
5941e0b7ea
@ -61,3 +61,51 @@ def test_sampler_different(model_name: str):
|
||||
# to have deterministic results over many tokens, tests the first ~20
|
||||
# tokens match.
|
||||
assert output[0].outputs[0].text[:20] == output[1].outputs[0].text[:20]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
|
||||
# TODO TPU will appear busy if we fan-out test params here
|
||||
@pytest.mark.parametrize("n_prompts", [1])
|
||||
@pytest.mark.skipif(not current_platform.is_tpu(),
|
||||
reason="This test needs a TPU")
|
||||
def test_logprobs(model_name: str, n_prompts: int):
|
||||
"""
|
||||
Request top logprobs with different sampling settings and check
|
||||
that results contains the requested number, ordered ascendingly.
|
||||
"""
|
||||
|
||||
def check_num_logprobs(logprobs, expected_num: int):
|
||||
for step in logprobs:
|
||||
prev_logp = 1.0
|
||||
# order by rank
|
||||
sorted_step = dict(
|
||||
sorted(step.items(), key=lambda item: item[1].rank))
|
||||
|
||||
# Can contain the sampled token
|
||||
assert len(step) == expected_num or len(step) == expected_num + 1
|
||||
# Check results are ordered by prob value
|
||||
for rankno, (tid, logp) in enumerate(sorted_step.items()):
|
||||
assert logp.logprob <= prev_logp
|
||||
prev_logp = logp.logprob
|
||||
assert logp.rank == rankno + 1
|
||||
|
||||
llm = LLM(model_name,
|
||||
enforce_eager=False,
|
||||
max_num_seqs=1,
|
||||
max_model_len=128,
|
||||
max_num_batched_tokens=128)
|
||||
prompts = [
|
||||
"Write a short story about a robot that dreams for the first time."
|
||||
] * n_prompts
|
||||
greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64,\
|
||||
logprobs=4)
|
||||
regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\
|
||||
logprobs=4)
|
||||
topkp_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\
|
||||
logprobs=4, top_k=12, top_p=0.5)
|
||||
|
||||
for sp in [greedy_sampling_params, regular_sampling_params, \
|
||||
topkp_sampling_params]:
|
||||
output = llm.generate(prompts, sp)
|
||||
for o in output:
|
||||
check_num_logprobs(o.outputs[0].logprobs, 4)
|
||||
|
||||
@ -31,8 +31,10 @@ class TPUSupportedSamplingMetadata:
|
||||
|
||||
all_greedy: bool = True
|
||||
|
||||
# unsupported, you need to return an extra tensor of static size BxV
|
||||
max_num_logprobs = None
|
||||
# Whether logprobs are to be gathered in this batch of request. To balance
|
||||
# out compile time and runtime, a fixed `max_number_logprobs` value is used
|
||||
# when gathering logprobs, regardless of the values specified in the batch.
|
||||
logprobs: bool = False
|
||||
|
||||
# TODO No penalties for now
|
||||
no_penalties: bool = True
|
||||
@ -84,10 +86,12 @@ class TPUSupportedSamplingMetadata:
|
||||
we want to pre-compile a graph with sampling parameters, even if
|
||||
they are not strictly needed for greedy decoding.
|
||||
"""
|
||||
needs_logprobs = input_batch.max_num_logprobs>0 if \
|
||||
input_batch.max_num_logprobs else False
|
||||
# Early return to avoid unnecessary cpu to tpu copy
|
||||
if (input_batch.all_greedy is True
|
||||
and generate_params_if_all_greedy is False):
|
||||
return cls(all_greedy=True)
|
||||
return cls(all_greedy=True, logprobs=needs_logprobs)
|
||||
|
||||
num_reqs = input_batch.num_reqs
|
||||
|
||||
@ -115,4 +119,5 @@ class TPUSupportedSamplingMetadata:
|
||||
top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
|
||||
xla_device),
|
||||
min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
|
||||
xla_device))
|
||||
xla_device),
|
||||
logprobs=needs_logprobs)
|
||||
|
||||
@ -22,27 +22,18 @@ class Sampler(nn.Module):
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: TPUSupportedSamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
# NOTE(woosuk): Use the original logits (before any penalties or
|
||||
# temperature scaling) for the top-k logprobs.
|
||||
# This is different from the V0 sampler, which uses the logits that
|
||||
# is used for sampling (after penalties and temperature scaling).
|
||||
|
||||
# Use float32 for the logits.
|
||||
logits = logits.to(torch.float32)
|
||||
# Sample the next token.
|
||||
sampled = self.sample(logits, sampling_metadata)
|
||||
|
||||
# Use int32 to reduce the tensor size.
|
||||
sampled = sampled.to(torch.int32)
|
||||
|
||||
# These are GPU tensors.
|
||||
# These are TPU tensors.
|
||||
sampler_output = SamplerOutput(
|
||||
# The sampled tokens are expanded to 2D tensor with shape
|
||||
# [num_requests, 1], where each row represents one generated
|
||||
# token per request.
|
||||
sampled_token_ids=sampled.unsqueeze(-1),
|
||||
logprobs_tensors=None,
|
||||
)
|
||||
logprobs_tensors=None)
|
||||
return sampler_output
|
||||
|
||||
def apply_temperature(
|
||||
@ -50,7 +41,6 @@ class Sampler(nn.Module):
|
||||
logits: torch.Tensor,
|
||||
temp: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Use in-place division to avoid creating a new tensor.
|
||||
return logits.div_(temp.unsqueeze(dim=1))
|
||||
|
||||
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@ -791,8 +791,18 @@ class TPUModelRunner:
|
||||
arange)
|
||||
selected_token_ids = self.sample_from_logits(logits,
|
||||
tpu_sampling_metadata)
|
||||
|
||||
# NOTE (NickLucche) Use the original logits (before any penalties or
|
||||
# temperature scaling) for the top-k logprobs. We can't enforce it due
|
||||
# to recompilations outside torch.compiled code, so just make sure
|
||||
# `sample_from_logits` does not modify the logits in-place.
|
||||
logprobs = self.gather_logprobs(logits, selected_token_ids) \
|
||||
if tpu_sampling_metadata.logprobs else None
|
||||
|
||||
# Remove padding on cpu and keep dynamic op outside of xla graph.
|
||||
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
|
||||
logprobs_lists = logprobs.tolists() \
|
||||
if tpu_sampling_metadata.logprobs else None
|
||||
|
||||
# Update the cache state concurrently. Code above will not block until
|
||||
# we use `selected_token_ids`. Add mark_step if post-processing changes
|
||||
@ -862,7 +872,7 @@ class TPUModelRunner:
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=valid_sampled_token_ids,
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
logprobs=logprobs_lists,
|
||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||
)
|
||||
|
||||
@ -1121,6 +1131,22 @@ class TPUModelRunner:
|
||||
logger.info("Compilation finished in %.2f [secs].", end - start)
|
||||
self._update_num_xla_graphs("sample_from_logits")
|
||||
|
||||
def _precompile_gather_logprobs(self) -> None:
|
||||
logger.info("Compiling gather_logprobs with different input shapes.")
|
||||
start = time.perf_counter()
|
||||
for num_reqs in self.num_reqs_paddings:
|
||||
dummy_logits = torch.zeros((num_reqs, self.vocab_size),
|
||||
device=self.device,
|
||||
dtype=self._hidden_states_dtype)
|
||||
dummy_tokens = torch.zeros((num_reqs, 1),
|
||||
dtype=torch.int64).to(self.device)
|
||||
self.gather_logprobs(dummy_logits, dummy_tokens)
|
||||
logger.info(" -- num_seqs: %d", num_reqs)
|
||||
xm.wait_device_ops()
|
||||
end = time.perf_counter()
|
||||
logger.info("Compilation finished in %.2f [secs].", end - start)
|
||||
self._update_num_xla_graphs("gather_logprobs")
|
||||
|
||||
def capture_model(self) -> None:
|
||||
"""
|
||||
Precompile all the subgraphs with possible input shapes.
|
||||
@ -1131,6 +1157,7 @@ class TPUModelRunner:
|
||||
self._precompile_compute_logits()
|
||||
self._precompile_structured_decoding()
|
||||
self._precompile_sample_from_logits()
|
||||
self._precompile_gather_logprobs()
|
||||
|
||||
def profile_run(
|
||||
self,
|
||||
@ -1254,6 +1281,10 @@ class TPUModelRunner:
|
||||
def sample_from_logits(
|
||||
self, logits: torch.Tensor,
|
||||
sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
|
||||
"""
|
||||
Sample with xla-friendly function. This function is to be traced
|
||||
separately from `forward` for lighter compilation overhead.
|
||||
"""
|
||||
if sampling_metadata.all_greedy:
|
||||
out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
|
||||
else:
|
||||
@ -1261,6 +1292,20 @@ class TPUModelRunner:
|
||||
sampling_metadata).sampled_token_ids
|
||||
return out_tokens
|
||||
|
||||
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
|
||||
def gather_logprobs(self, logits: torch.Tensor,
|
||||
sampled_tokens: torch.Tensor) -> LogprobsTensors:
|
||||
"""
|
||||
Gather the top_logprobs with corresponding tokens. Use a fixed number
|
||||
of logprobs as an alternative to having multiple pre-compiled graphs.
|
||||
Select the number of logprobs actually demanded by each request on CPU.
|
||||
"""
|
||||
logprobs = self.sampler.compute_logprobs(logits)
|
||||
return self.sampler.gather_logprobs(
|
||||
logprobs,
|
||||
self.model_config.max_logprobs,
|
||||
token_ids=sampled_tokens.squeeze(-1))
|
||||
|
||||
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
|
||||
def structured_decode(self, require_struct_decoding: torch.Tensor,
|
||||
grammar_bitmask: torch.Tensor, logits: torch.Tensor,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user