Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-13 21:55:32 +08:00
[TPU] Remove redundant input tensor cloning (#7660)
This commit is contained in:
parent da115230fd
commit 43735bf5e1
@@ -516,27 +516,19 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             raise ValueError(
                 "TPUModelRunner does not support multi-step execution.")
 
-        def _execute_model(*args, clone: bool = False) -> torch.Tensor:
+        def _execute_model(*args):
             """Move input args from CPU to device and execute the model."""
 
-            def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
-                if clone:
-                    # When x is a slice of a CPU tensor, XLA may copy the whole
-                    # original tensor to TPU instead of only copying x.
-                    # To avoid this, we copy x after cloning.
-                    x = x.clone()
-                return x.to(self.device)
-
             new_args = []
             for arg in args:
                 if isinstance(arg, torch.Tensor):
-                    arg = _copy_to_device(arg)
+                    arg = arg.to(self.device)
                 elif isinstance(arg, AttentionMetadata):
-                    arg.slot_mapping = _copy_to_device(arg.slot_mapping)
+                    arg.slot_mapping = arg.slot_mapping.to(self.device)
                     if getattr(arg, "block_tables", None) is not None:
-                        arg.block_tables = _copy_to_device(arg.block_tables)
+                        arg.block_tables = arg.block_tables.to(self.device)
                     if getattr(arg, "context_lens", None) is not None:
-                        arg.context_lens = _copy_to_device(arg.context_lens)
+                        arg.context_lens = arg.context_lens.to(self.device)
                 new_args.append(arg)
             return self.model(*new_args)
 
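Context for the comment deleted above: slicing a CPU tensor yields a view over the original storage, and the old clone=True path materialized the slice into its own small buffer before the host-to-device transfer. Below is a minimal sketch of both transfer patterns, assuming a working PyTorch/XLA install (torch_xla and xm.xla_device() are the only XLA APIs used; the tensor names and shapes are illustrative, not taken from the diff):

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()  # the XLA (TPU) device

    big = torch.zeros(1024, 1024)  # large CPU buffer
    view = big[0:8]                # small slice; a view sharing big's storage
    # untyped_storage() needs PyTorch >= 2.0; both tensors share one buffer.
    assert view.untyped_storage().data_ptr() == big.untyped_storage().data_ptr()

    # Old clone=True path: materialize the slice into its own small buffer
    # before the transfer, so the copy cannot touch the full 1024x1024 buffer.
    on_device_cloned = view.clone().to(device)

    # Path kept by this commit: transfer the slice directly.
    on_device = view.to(device)

    assert on_device.shape == (8, 1024)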
@@ -563,13 +555,9 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             output_token_ids = _execute_model(
                 model_input.token_ids[None, start_idx:end_idx],
                 model_input.position_ids[None, start_idx:end_idx],
-                model_input.attn_metadata,
-                model_input.input_lens[i:i + 1],
-                model_input.t[i:i + 1],
-                model_input.p[i:i + 1],
-                model_input.num_samples,
-                kv_caches,
-                clone=True)
+                model_input.attn_metadata, model_input.input_lens[i:i + 1],
+                model_input.t[i:i + 1], model_input.p[i:i + 1],
+                model_input.num_samples, kv_caches)
             # Retrieve the outputs to CPU.
             next_token_ids += output_token_ids.cpu().tolist()
             start_idx = end_idx
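The clone() dropped at this call site is numerically a no-op: it only controls whether each slice gets its own buffer before the host-to-device copy. A CPU-only sketch of that invariant (the names mirror the call site, but the values and indices are made up):

    import torch

    token_ids = torch.arange(32)  # stand-in for the 1-D CPU model_input.token_ids
    start_idx, end_idx = 8, 16

    sliced = token_ids[None, start_idx:end_idx]  # shape (1, 8); a view
    cloned = sliced.clone()                      # same values, fresh buffer

    # Dropping clone=True cannot change the model's outputs; it only skips
    # one host-side copy per executed chunk.
    assert torch.equal(sliced, cloned)
    assert sliced.untyped_storage().data_ptr() != cloned.untyped_storage().data_ptr()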