mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-21 20:14:34 +08:00
[TPU][V1] Fix Sampler recompilation (#15309)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
e977c11111
commit
a0dd7dcd49
@ -5,7 +5,18 @@ from typing import Optional
|
||||
import torch
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
DEFAULT_SAMPLING_PARAMS = dict(
|
||||
temperature=-1.0,
|
||||
min_p=0.0,
|
||||
# strictly disabled for now
|
||||
# top_k=-1,
|
||||
# top_p=0.0,
|
||||
# frequency_penalties=0.0,
|
||||
# presence_penalties=0.0,
|
||||
# repetition_penalties=0.0,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -20,14 +31,8 @@ class TPUSupportedSamplingMetadata:
|
||||
top_k: torch.Tensor = None
|
||||
top_p: torch.Tensor = None
|
||||
|
||||
# XLA-unfriendly control flow in Sampler
|
||||
all_greedy: bool = False
|
||||
all_random: bool = False
|
||||
# Greedy sampling flag for compiling single xla graph.
|
||||
do_argmax: torch.Tensor = None
|
||||
|
||||
# speculation not supported
|
||||
spec_token_ids = None
|
||||
all_greedy: torch.Tensor = None
|
||||
|
||||
# Generator not supported by xla
|
||||
generators: dict[int,
|
||||
@ -54,106 +59,68 @@ class TPUSupportedSamplingMetadata:
|
||||
bad_words_token_ids = None
|
||||
indices_do_sample: torch.Tensor = None
|
||||
|
||||
def __post_init__(self):
|
||||
temp = self.temperature
|
||||
if self.indices_do_sample is None:
|
||||
self.indices_do_sample = torch.zeros(temp.shape[0],
|
||||
device=temp.device,
|
||||
dtype=torch.int32)
|
||||
if self.do_argmax is None:
|
||||
self.do_argmax = torch.tensor(0,
|
||||
dtype=torch.bool,
|
||||
device=temp.device)
|
||||
|
||||
@classmethod
|
||||
def from_sampling_metadata(
|
||||
cls, metadata: SamplingMetadata,
|
||||
padded_do_sample_indices: torch.Tensor, num_do_sample: int,
|
||||
device: torch.device) -> "TPUSupportedSamplingMetadata":
|
||||
def from_input_batch(
|
||||
cls, input_batch: InputBatch,
|
||||
indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata":
|
||||
"""
|
||||
Create an XLA-frienly SamplingMetadata structure. Do so by first
|
||||
instantiating an object with fixed-sized tensors and then writing the
|
||||
values in input `metadata`. Do that only for non-None values so that
|
||||
recompilation is not triggered for optional values (None/torch.Tensor).
|
||||
|
||||
In order to handle different sizes for the params that range from 1 up
|
||||
to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
|
||||
Same thing for `padded_do_sample_indices`, which contains the indices
|
||||
to be fed to the Sampler, padded to the closest pre-compiled shape.
|
||||
Copy sampling tensors slices from `input_batch` to on device tensors.
|
||||
|
||||
Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0]
|
||||
do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]
|
||||
`InputBatch._make_sampling_metadata` causes recompilation on XLA as it
|
||||
slices dynamic shapes on device tensors. This impl moves the dynamic
|
||||
ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
|
||||
also reuses the on-device persistent tensors managed in `input_batch`
|
||||
to reduce waste.
|
||||
|
||||
`indices_do_sample` contains the indices to be fed to the Sampler,
|
||||
normally one per request, here padded to the closest pre-compiled shape
|
||||
We expect sampling params tensors to be padded to the same fixed shape.
|
||||
|
||||
Eg. 3 requests, tensors padded to 4
|
||||
temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
|
||||
sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
|
||||
"""
|
||||
metadata = cls._validate_sampling_metadata(metadata)
|
||||
# NOTE we have to initialize default tensor-based params first and
|
||||
# skip None values altogether to produce the same xla graph.
|
||||
num_samples = len(padded_do_sample_indices)
|
||||
do_argmax = torch.tensor(metadata.all_greedy,
|
||||
dtype=torch.bool,
|
||||
device=device)
|
||||
new_metadata = cls.get_default_sampling_params(num_samples, device,
|
||||
indices_do_sample=\
|
||||
padded_do_sample_indices,
|
||||
do_argmax=do_argmax
|
||||
)
|
||||
supported_params = \
|
||||
TPUSupportedSamplingMetadata._get_default_params_values()
|
||||
# Copy input non-None values into `new_metadata` fixed-sized tensors.
|
||||
for p_name in supported_params:
|
||||
old_val = getattr(metadata, p_name)
|
||||
new_val = getattr(new_metadata, p_name)
|
||||
if isinstance(old_val, torch.Tensor):
|
||||
new_val[:num_do_sample] = old_val
|
||||
setattr(new_metadata, p_name, new_val)
|
||||
num_reqs = input_batch.num_reqs
|
||||
padded_num_reqs = len(indices_do_sample)
|
||||
|
||||
def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
|
||||
fill_val) -> torch.Tensor:
|
||||
# Copy slice from CPU to corresponding TPU pre-allocated tensor.
|
||||
# Pad value is the default one.
|
||||
cpu_tensor[num_reqs:padded_num_reqs] = fill_val
|
||||
tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]
|
||||
|
||||
# NOTE NickLucche The sync CPU-TPU graph we produce here must be
|
||||
# consistent. We can't have flags to skip copies or we'll end up
|
||||
# recompiling.
|
||||
copy_slice(input_batch.temperature_cpu_tensor, input_batch.temperature,
|
||||
DEFAULT_SAMPLING_PARAMS["temperature"])
|
||||
# TODO Temporarily disabled until sampling options are enabled
|
||||
# copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
|
||||
# copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
|
||||
copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
|
||||
DEFAULT_SAMPLING_PARAMS["min_p"])
|
||||
|
||||
# copy_slice(input_batch.frequency_penalties_cpu_tensor,
|
||||
# input_batch.frequency_penalties)
|
||||
# copy_slice(input_batch.presence_penalties_cpu_tensor,
|
||||
# input_batch.presence_penalties)
|
||||
# copy_slice(input_batch.repetition_penalties_cpu_tensor,
|
||||
# input_batch.repetition_penalties)
|
||||
|
||||
xm.mark_step()
|
||||
xm.wait_device_ops()
|
||||
return new_metadata
|
||||
|
||||
@classmethod
|
||||
def get_default_sampling_params(
|
||||
cls,
|
||||
num_samples: int,
|
||||
device: torch.device,
|
||||
indices_do_sample=None,
|
||||
do_argmax=None) -> "TPUSupportedSamplingMetadata":
|
||||
# As sampling happens on a single traced graph, options
|
||||
# are "disabled" by having them evaluate to an Identity op.
|
||||
# Note that initialization is dependent on num_samples.
|
||||
sampling_metadata_disable_value = \
|
||||
TPUSupportedSamplingMetadata._get_default_params_values()
|
||||
init_kwargs = dict()
|
||||
for p_name, (default_val,
|
||||
dtype) in sampling_metadata_disable_value.items():
|
||||
default_tensor = torch.full((num_samples, ),
|
||||
default_val,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
init_kwargs[p_name] = default_tensor
|
||||
|
||||
return cls(**init_kwargs,
|
||||
indices_do_sample=indices_do_sample,
|
||||
do_argmax=do_argmax)
|
||||
|
||||
@staticmethod
|
||||
def _validate_sampling_metadata(
|
||||
sampling_metadata: SamplingMetadata) -> SamplingMetadata:
|
||||
if sampling_metadata.all_greedy:
|
||||
# Set to None since #13587. Make sure default isn't overruled.
|
||||
assert sampling_metadata.temperature is None
|
||||
return sampling_metadata
|
||||
|
||||
@staticmethod
|
||||
def _get_default_params_values():
|
||||
return dict(
|
||||
# Since #13587 greedy sampling requires branching off which leads
|
||||
# to separate graphs. We set temp to noop and handle argmax here.
|
||||
temperature=(1.0, torch.float32),
|
||||
min_p=(0.0, torch.float32),
|
||||
# strictly disabled for now
|
||||
# top_k=(-1, torch.int32),
|
||||
# top_p=(0.0, torch.float32),
|
||||
# frequency_penalties=(0.0, torch.float32),
|
||||
# presence_penalties=(0.0, torch.float32),
|
||||
# repetition_penalties=(0.0, torch.float32),
|
||||
)
|
||||
# Slice persistent device tensors to a fixed pre-compiled padded shape.
|
||||
return cls(
|
||||
temperature=input_batch.temperature[:padded_num_reqs],
|
||||
# Scalar tensor for xla-friendly tracing.
|
||||
all_greedy=torch.tensor(input_batch.all_greedy,
|
||||
dtype=torch.bool,
|
||||
device=input_batch.device),
|
||||
# TODO enable more and avoid returning None values
|
||||
top_p=None, # input_batch.top_p[:padded_num_reqs],
|
||||
top_k=None, # input_batch.top_k[:padded_num_reqs],
|
||||
min_p=input_batch.min_p[:padded_num_reqs],
|
||||
generators=input_batch.generators,
|
||||
indices_do_sample=indices_do_sample)
|
||||
|
||||
@ -279,9 +279,6 @@ class TPUModelRunner:
|
||||
req_data.num_computed_tokens)
|
||||
self.input_batch.block_table.append_row(req_data.new_block_ids,
|
||||
req_index)
|
||||
# Check if the batch has changed. If not, we can skip copying the
|
||||
# sampling metadata from CPU to GPU.
|
||||
batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
|
||||
|
||||
# Add the new or resumed requests to the persistent batch.
|
||||
# The smaller empty indices are filled first.
|
||||
@ -300,9 +297,6 @@ class TPUModelRunner:
|
||||
if removed_req_indices:
|
||||
self.input_batch.condense(removed_req_indices)
|
||||
|
||||
# TODO This slices tensors to copy to device, triggering recompilation.
|
||||
if batch_changed:
|
||||
self.input_batch.refresh_sampling_metadata()
|
||||
return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
@ -597,14 +591,12 @@ class TPUModelRunner:
|
||||
# then the embedding layer is not included in the CUDA graph.
|
||||
input_ids = self.input_ids
|
||||
inputs_embeds = None
|
||||
sampling_metadata = self.input_batch.sampling_metadata
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
# NOTE (NickLucche) here we sync with TPU: if there's any shape
|
||||
# mismatch in pre-processing, it will trigger a small recompilation
|
||||
# of the code thus far. Forward graph remains untouched.
|
||||
# NOTE (NickLucche) here we sync with TPU: sampling params tensors
|
||||
# are copied to device in chunks of pre-compiled padded shape to
|
||||
# avoid recompilations.
|
||||
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
|
||||
from_sampling_metadata(sampling_metadata, logits_indices,
|
||||
num_reqs, self.device)
|
||||
from_input_batch(self.input_batch, logits_indices)
|
||||
# Run the decoder
|
||||
with set_forward_context(attn_metadata, self.vllm_config):
|
||||
hidden_states = self.model(
|
||||
@ -797,21 +789,19 @@ class TPUModelRunner:
|
||||
device=device,
|
||||
dtype=torch.bfloat16)
|
||||
while True:
|
||||
# Default metadata is an all_greedy setup. But since the
|
||||
# `do_argmax` flag is a tensor, we still compile the full graph
|
||||
meta = self.input_batch.sampling_metadata
|
||||
indices = torch.zeros(
|
||||
num_reqs_to_sample,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
xm.mark_step()
|
||||
sampling_meta = TPUSupportedSamplingMetadata.\
|
||||
from_sampling_metadata(meta, indices,
|
||||
num_reqs_to_sample, device)
|
||||
from_input_batch(self.input_batch, indices)
|
||||
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
|
||||
num_reqs_to_sample)
|
||||
self.model.sample_from_hidden(dummy_hidden, sampling_meta)
|
||||
xm.mark_step()
|
||||
out = self.model.sample_from_hidden(dummy_hidden,
|
||||
sampling_meta)
|
||||
out = out.cpu()
|
||||
if num_reqs_to_sample >= self.max_num_reqs:
|
||||
break
|
||||
num_reqs_to_sample *= 2
|
||||
@ -910,6 +900,7 @@ class ModelWrapperV1(nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
# @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
|
||||
def sample_from_hidden(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -923,10 +914,9 @@ class ModelWrapperV1(nn.Module):
|
||||
sample_hidden_states = \
|
||||
hidden_states[sampling_metadata.indices_do_sample]
|
||||
logits = self.compute_logits(sample_hidden_states)
|
||||
# Greedy sampling can't be run without branching the graph on Sampler.
|
||||
# Therefore do_argmax/all_greedy is checked here in a xla-friendly way.
|
||||
# NOTE do_argmax is a scalar, this is just an optimized if/else.
|
||||
out_tokens = torch.where(sampling_metadata.do_argmax,
|
||||
# Optimized greedy sampling branch, tracing both paths in a single pass
|
||||
# NOTE all_greedy is a scalar, this is just an optimized if/else.
|
||||
out_tokens = torch.where(sampling_metadata.all_greedy,
|
||||
torch.argmax(logits, dim=-1, keepdim=True),
|
||||
self.sample(logits, sampling_metadata)\
|
||||
.sampled_token_ids)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user