[TPU][V1] Fix Sampler recompilation (#15309)

Signed-off-by: NickLucche <nlucches@redhat.com>
Nicolò Lucchesi 2025-03-25 21:43:54 +01:00 committed by GitHub
parent e977c11111
commit a0dd7dcd49
2 changed files with 83 additions and 126 deletions

View File

@@ -5,7 +5,18 @@ from typing import Optional
import torch
import torch_xla.core.xla_model as xm
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import InputBatch
DEFAULT_SAMPLING_PARAMS = dict(
temperature=-1.0,
min_p=0.0,
# strictly disabled for now
# top_k=-1,
# top_p=0.0,
# frequency_penalties=0.0,
# presence_penalties=0.0,
# repetition_penalties=0.0,
)
@dataclass
@@ -20,14 +31,8 @@ class TPUSupportedSamplingMetadata:
top_k: torch.Tensor = None
top_p: torch.Tensor = None
# XLA-unfriendly control flow in Sampler
all_greedy: bool = False
all_random: bool = False
# Greedy sampling flag for compiling a single XLA graph.
do_argmax: torch.Tensor = None
# speculation not supported
spec_token_ids = None
all_greedy: torch.Tensor = None
# Generator not supported by xla
generators: dict[int,
@@ -54,106 +59,68 @@ class TPUSupportedSamplingMetadata:
bad_words_token_ids = None
indices_do_sample: torch.Tensor = None
def __post_init__(self):
temp = self.temperature
if self.indices_do_sample is None:
self.indices_do_sample = torch.zeros(temp.shape[0],
device=temp.device,
dtype=torch.int32)
if self.do_argmax is None:
self.do_argmax = torch.tensor(0,
dtype=torch.bool,
device=temp.device)
@classmethod
def from_sampling_metadata(
cls, metadata: SamplingMetadata,
padded_do_sample_indices: torch.Tensor, num_do_sample: int,
device: torch.device) -> "TPUSupportedSamplingMetadata":
def from_input_batch(
cls, input_batch: InputBatch,
indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata":
"""
Create an XLA-friendly SamplingMetadata structure. Do so by first
instantiating an object with fixed-sized tensors and then writing the
values from the input `metadata`. Do that only for non-None values so that
recompilation is not triggered for optional values (None/torch.Tensor).
To handle param sizes that range from 1 up to `max_num_seqs`, pad
tensors to the closest pre-compiled shape. The same applies to
`padded_do_sample_indices`, which contains the indices to be fed to the
Sampler, padded to the closest pre-compiled shape.
Copy sampling tensor slices from `input_batch` to on-device tensors.
Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0]
do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]
`InputBatch._make_sampling_metadata` causes recompilation on XLA as it
slices on-device tensors with dynamic shapes. This impl moves the dynamic
ops to the CPU and produces tensors of fixed `padded_num_reqs` size. It
also reuses the on-device persistent tensors managed in `input_batch`
to reduce waste.
`indices_do_sample` contains the indices to be fed to the Sampler,
normally one per request, here padded to the closest pre-compiled shape.
We expect the sampling params tensors to be padded to the same fixed shape.
Eg. 3 requests, tensors padded to 4
temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
"""
metadata = cls._validate_sampling_metadata(metadata)
# NOTE we have to initialize default tensor-based params first and
# skip None values altogether to produce the same xla graph.
num_samples = len(padded_do_sample_indices)
do_argmax = torch.tensor(metadata.all_greedy,
dtype=torch.bool,
device=device)
new_metadata = cls.get_default_sampling_params(num_samples, device,
indices_do_sample=\
padded_do_sample_indices,
do_argmax=do_argmax
)
supported_params = \
TPUSupportedSamplingMetadata._get_default_params_values()
# Copy input non-None values into `new_metadata` fixed-sized tensors.
for p_name in supported_params:
old_val = getattr(metadata, p_name)
new_val = getattr(new_metadata, p_name)
if isinstance(old_val, torch.Tensor):
new_val[:num_do_sample] = old_val
setattr(new_metadata, p_name, new_val)
num_reqs = input_batch.num_reqs
padded_num_reqs = len(indices_do_sample)
def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
fill_val) -> torch.Tensor:
# Copy slice from CPU to corresponding TPU pre-allocated tensor.
# Pad value is the default one.
cpu_tensor[num_reqs:padded_num_reqs] = fill_val
tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]
# NOTE (NickLucche) The CPU-TPU sync graph we produce here must be
# consistent. We can't use flags to skip copies, or we'll end up
# recompiling.
copy_slice(input_batch.temperature_cpu_tensor, input_batch.temperature,
DEFAULT_SAMPLING_PARAMS["temperature"])
# TODO Temporarily disabled until sampling options are enabled
# copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
# copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
DEFAULT_SAMPLING_PARAMS["min_p"])
# copy_slice(input_batch.frequency_penalties_cpu_tensor,
# input_batch.frequency_penalties)
# copy_slice(input_batch.presence_penalties_cpu_tensor,
# input_batch.presence_penalties)
# copy_slice(input_batch.repetition_penalties_cpu_tensor,
# input_batch.repetition_penalties)
xm.mark_step()
xm.wait_device_ops()
return new_metadata
@classmethod
def get_default_sampling_params(
cls,
num_samples: int,
device: torch.device,
indices_do_sample=None,
do_argmax=None) -> "TPUSupportedSamplingMetadata":
# As sampling happens on a single traced graph, options
# are "disabled" by having them evaluate to an Identity op.
# Note that initialization is dependent on num_samples.
sampling_metadata_disable_value = \
TPUSupportedSamplingMetadata._get_default_params_values()
init_kwargs = dict()
for p_name, (default_val,
dtype) in sampling_metadata_disable_value.items():
default_tensor = torch.full((num_samples, ),
default_val,
dtype=dtype,
device=device)
init_kwargs[p_name] = default_tensor
return cls(**init_kwargs,
indices_do_sample=indices_do_sample,
do_argmax=do_argmax)
@staticmethod
def _validate_sampling_metadata(
sampling_metadata: SamplingMetadata) -> SamplingMetadata:
if sampling_metadata.all_greedy:
# Set to None since #13587. Make sure the default isn't overridden.
assert sampling_metadata.temperature is None
return sampling_metadata
@staticmethod
def _get_default_params_values():
return dict(
# Since #13587, greedy sampling requires branching, which leads
# to separate graphs. We set temperature to a no-op value and handle argmax here.
temperature=(1.0, torch.float32),
min_p=(0.0, torch.float32),
# strictly disabled for now
# top_k=(-1, torch.int32),
# top_p=(0.0, torch.float32),
# frequency_penalties=(0.0, torch.float32),
# presence_penalties=(0.0, torch.float32),
# repetition_penalties=(0.0, torch.float32),
)
# Slice persistent device tensors to a fixed pre-compiled padded shape.
return cls(
temperature=input_batch.temperature[:padded_num_reqs],
# Scalar tensor for xla-friendly tracing.
all_greedy=torch.tensor(input_batch.all_greedy,
dtype=torch.bool,
device=input_batch.device),
# TODO enable more and avoid returning None values
top_p=None, # input_batch.top_p[:padded_num_reqs],
top_k=None, # input_batch.top_k[:padded_num_reqs],
min_p=input_batch.min_p[:padded_num_reqs],
generators=input_batch.generators,
indices_do_sample=indices_do_sample)
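To make the fixed-shape copy pattern above easier to follow, here is a minimal standalone sketch in plain PyTorch on CPU (the tensor names and sizes are made up for illustration, not the vLLM implementation): the CPU buffer is padded to the pre-compiled size with the default fill value, a fixed-size slice is copied into the persistent device tensor, and only fixed shapes ever reach the traced graph.

import torch

MAX_NUM_REQS = 8      # capacity of the persistent tensors
PADDED_NUM_REQS = 4   # closest pre-compiled padded shape
NUM_REQS = 3          # live requests in this batch

# Stand-ins for input_batch.temperature_cpu_tensor / input_batch.temperature.
temperature_cpu = torch.zeros(MAX_NUM_REQS)
temperature_dev = torch.zeros(MAX_NUM_REQS)  # would live on the TPU in vLLM
temperature_cpu[:NUM_REQS] = torch.tensor([0.7, 0.2, 0.9])

def copy_slice(cpu_tensor, device_tensor, fill_val):
    # Pad the tail with the default value, then copy a fixed-size slice so
    # the shape seen by the compiled graph never depends on NUM_REQS.
    cpu_tensor[NUM_REQS:PADDED_NUM_REQS] = fill_val
    device_tensor[:PADDED_NUM_REQS] = cpu_tensor[:PADDED_NUM_REQS]

copy_slice(temperature_cpu, temperature_dev, fill_val=-1.0)
print(temperature_dev[:PADDED_NUM_REQS])  # [0.7, 0.2, 0.9, -1.0]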

View File

@@ -279,9 +279,6 @@ class TPUModelRunner:
req_data.num_computed_tokens)
self.input_batch.block_table.append_row(req_data.new_block_ids,
req_index)
# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to GPU.
batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
@@ -300,9 +297,6 @@ class TPUModelRunner:
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
# TODO This slices tensors to copy to device, triggering recompilation.
if batch_changed:
self.input_batch.refresh_sampling_metadata()
return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
def get_model(self) -> nn.Module:
@@ -597,14 +591,12 @@ class TPUModelRunner:
# then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids
inputs_embeds = None
sampling_metadata = self.input_batch.sampling_metadata
num_reqs = self.input_batch.num_reqs
# NOTE (NickLucche) here we sync with TPU: if there's any shape
# mismatch in pre-processing, it will trigger a small recompilation
# of the code thus far. Forward graph remains untouched.
# NOTE (NickLucche) here we sync with TPU: sampling params tensors
# are copied to device in chunks of pre-compiled padded shape to
# avoid recompilations.
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
from_sampling_metadata(sampling_metadata, logits_indices,
num_reqs, self.device)
from_input_batch(self.input_batch, logits_indices)
# Run the decoder
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
@@ -797,21 +789,19 @@ class TPUModelRunner:
device=device,
dtype=torch.bfloat16)
while True:
# Default metadata is an all_greedy setup. But since the
# `do_argmax` flag is a tensor, we still compile the full graph.
meta = self.input_batch.sampling_metadata
indices = torch.zeros(
num_reqs_to_sample,
dtype=torch.int32,
device=device,
)
xm.mark_step()
sampling_meta = TPUSupportedSamplingMetadata.\
from_sampling_metadata(meta, indices,
num_reqs_to_sample, device)
from_input_batch(self.input_batch, indices)
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
num_reqs_to_sample)
self.model.sample_from_hidden(dummy_hidden, sampling_meta)
xm.mark_step()
out = self.model.sample_from_hidden(dummy_hidden,
sampling_meta)
out = out.cpu()
if num_reqs_to_sample >= self.max_num_reqs:
break
num_reqs_to_sample *= 2
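For context, a rough sketch of why the warmup loop above doubles `num_reqs_to_sample`. The helper below is hypothetical (the minimum of 8 requests and the rounding rule are assumptions for illustration): warming up on power-of-two request counts lets any live batch be padded up to a shape that was already traced, so sampling compiles once per padded shape rather than once per batch size.

def padded_num_reqs(num_reqs: int, min_reqs: int = 8, max_reqs: int = 256) -> int:
    # Round up to the closest pre-compiled (power-of-two) request count.
    padded = min_reqs
    while padded < min(num_reqs, max_reqs):
        padded *= 2
    return padded

# Warmup traces graphs for 8, 16, 32, ..., max_reqs; a batch of 11 live
# requests then reuses the graph compiled for 16.
assert padded_num_reqs(11) == 16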
@@ -910,6 +900,7 @@ class ModelWrapperV1(nn.Module):
return hidden_states
# @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def sample_from_hidden(
self,
hidden_states: torch.Tensor,
@@ -923,10 +914,9 @@ class ModelWrapperV1(nn.Module):
sample_hidden_states = \
hidden_states[sampling_metadata.indices_do_sample]
logits = self.compute_logits(sample_hidden_states)
# Greedy sampling can't be run without branching the graph in the Sampler.
# Therefore do_argmax/all_greedy is checked here in an XLA-friendly way.
# NOTE do_argmax is a scalar, this is just an optimized if/else.
out_tokens = torch.where(sampling_metadata.do_argmax,
# Optimized greedy sampling branch: trace both paths in a single pass.
# NOTE all_greedy is a scalar tensor, so this is just an optimized if/else.
out_tokens = torch.where(sampling_metadata.all_greedy,
torch.argmax(logits, dim=-1, keepdim=True),
self.sample(logits, sampling_metadata)\
.sampled_token_ids)
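As a standalone illustration of the `torch.where` trick above, here is a small sketch in plain PyTorch with made-up logits; `torch.multinomial` stands in for the Sampler's random path and is only an assumption for the example. Both branches are computed inside the same traced graph, and the scalar boolean tensor merely selects which result is returned, so greedy and random batches share one compiled graph.

import torch

logits = torch.tensor([[0.1, 2.0, 0.3],
                       [1.5, 0.2, 0.4]])
all_greedy = torch.tensor(True)  # scalar tensor, not a Python bool

greedy_tokens = torch.argmax(logits, dim=-1, keepdim=True)
random_tokens = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)

# Both candidate outputs exist either way; torch.where only picks one, so no
# Python-level branch (and no second XLA graph) is introduced.
out_tokens = torch.where(all_greedy, greedy_tokens, random_tokens)
print(out_tokens.squeeze(-1))  # tensor([1, 0]) when all_greedy is True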