[TPU] Remove redundant input tensor cloning (#7660)

This commit is contained in:
Woosuk Kwon 2024-08-19 15:55:04 -07:00 committed by GitHub
parent da115230fd
commit 43735bf5e1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -516,27 +516,19 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
raise ValueError(
"TPUModelRunner does not support multi-step execution.")
def _execute_model(*args, clone: bool = False) -> torch.Tensor:
def _execute_model(*args):
"""Move input args from CPU to device and execute the model."""
def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
if clone:
# When x is a slice of a CPU tensor, XLA may copy the whole
# original tensor to TPU instead of only copying x.
# To avoid this, we copy x after cloning.
x = x.clone()
return x.to(self.device)
new_args = []
for arg in args:
if isinstance(arg, torch.Tensor):
arg = _copy_to_device(arg)
arg = arg.to(self.device)
elif isinstance(arg, AttentionMetadata):
arg.slot_mapping = _copy_to_device(arg.slot_mapping)
arg.slot_mapping = arg.slot_mapping.to(self.device)
if getattr(arg, "block_tables", None) is not None:
arg.block_tables = _copy_to_device(arg.block_tables)
arg.block_tables = arg.block_tables.to(self.device)
if getattr(arg, "context_lens", None) is not None:
arg.context_lens = _copy_to_device(arg.context_lens)
arg.context_lens = arg.context_lens.to(self.device)
new_args.append(arg)
return self.model(*new_args)
@ -563,13 +555,9 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
output_token_ids = _execute_model(
model_input.token_ids[None, start_idx:end_idx],
model_input.position_ids[None, start_idx:end_idx],
model_input.attn_metadata,
model_input.input_lens[i:i + 1],
model_input.t[i:i + 1],
model_input.p[i:i + 1],
model_input.num_samples,
kv_caches,
clone=True)
model_input.attn_metadata, model_input.input_lens[i:i + 1],
model_input.t[i:i + 1], model_input.p[i:i + 1],
model_input.num_samples, kv_caches)
# Retrieve the outputs to CPU.
next_token_ids += output_token_ids.cpu().tolist()
start_idx = end_idx