mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:34:57 +08:00
[TPU] Remove redundant input tensor cloning (#7660)
This commit is contained in:
parent
da115230fd
commit
43735bf5e1
@ -516,27 +516,19 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
|
||||
raise ValueError(
|
||||
"TPUModelRunner does not support multi-step execution.")
|
||||
|
||||
def _execute_model(*args, clone: bool = False) -> torch.Tensor:
|
||||
def _execute_model(*args):
|
||||
"""Move input args from CPU to device and execute the model."""
|
||||
|
||||
def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
|
||||
if clone:
|
||||
# When x is a slice of a CPU tensor, XLA may copy the whole
|
||||
# original tensor to TPU instead of only copying x.
|
||||
# To avoid this, we copy x after cloning.
|
||||
x = x.clone()
|
||||
return x.to(self.device)
|
||||
|
||||
new_args = []
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.Tensor):
|
||||
arg = _copy_to_device(arg)
|
||||
arg = arg.to(self.device)
|
||||
elif isinstance(arg, AttentionMetadata):
|
||||
arg.slot_mapping = _copy_to_device(arg.slot_mapping)
|
||||
arg.slot_mapping = arg.slot_mapping.to(self.device)
|
||||
if getattr(arg, "block_tables", None) is not None:
|
||||
arg.block_tables = _copy_to_device(arg.block_tables)
|
||||
arg.block_tables = arg.block_tables.to(self.device)
|
||||
if getattr(arg, "context_lens", None) is not None:
|
||||
arg.context_lens = _copy_to_device(arg.context_lens)
|
||||
arg.context_lens = arg.context_lens.to(self.device)
|
||||
new_args.append(arg)
|
||||
return self.model(*new_args)
|
||||
|
||||
@ -563,13 +555,9 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
|
||||
output_token_ids = _execute_model(
|
||||
model_input.token_ids[None, start_idx:end_idx],
|
||||
model_input.position_ids[None, start_idx:end_idx],
|
||||
model_input.attn_metadata,
|
||||
model_input.input_lens[i:i + 1],
|
||||
model_input.t[i:i + 1],
|
||||
model_input.p[i:i + 1],
|
||||
model_input.num_samples,
|
||||
kv_caches,
|
||||
clone=True)
|
||||
model_input.attn_metadata, model_input.input_lens[i:i + 1],
|
||||
model_input.t[i:i + 1], model_input.p[i:i + 1],
|
||||
model_input.num_samples, kv_caches)
|
||||
# Retrieve the outputs to CPU.
|
||||
next_token_ids += output_token_ids.cpu().tolist()
|
||||
start_idx = end_idx
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user