diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 14f14e40b4c0b..01daa64b5a32f 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -516,27 +516,19 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): raise ValueError( "TPUModelRunner does not support multi-step execution.") - def _execute_model(*args, clone: bool = False) -> torch.Tensor: + def _execute_model(*args): """Move input args from CPU to device and execute the model.""" - def _copy_to_device(x: torch.Tensor) -> torch.Tensor: - if clone: - # When x is a slice of a CPU tensor, XLA may copy the whole - # original tensor to TPU instead of only copying x. - # To avoid this, we copy x after cloning. - x = x.clone() - return x.to(self.device) - new_args = [] for arg in args: if isinstance(arg, torch.Tensor): - arg = _copy_to_device(arg) + arg = arg.to(self.device) elif isinstance(arg, AttentionMetadata): - arg.slot_mapping = _copy_to_device(arg.slot_mapping) + arg.slot_mapping = arg.slot_mapping.to(self.device) if getattr(arg, "block_tables", None) is not None: - arg.block_tables = _copy_to_device(arg.block_tables) + arg.block_tables = arg.block_tables.to(self.device) if getattr(arg, "context_lens", None) is not None: - arg.context_lens = _copy_to_device(arg.context_lens) + arg.context_lens = arg.context_lens.to(self.device) new_args.append(arg) return self.model(*new_args) @@ -563,13 +555,9 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): output_token_ids = _execute_model( model_input.token_ids[None, start_idx:end_idx], model_input.position_ids[None, start_idx:end_idx], - model_input.attn_metadata, - model_input.input_lens[i:i + 1], - model_input.t[i:i + 1], - model_input.p[i:i + 1], - model_input.num_samples, - kv_caches, - clone=True) + model_input.attn_metadata, model_input.input_lens[i:i + 1], + model_input.t[i:i + 1], model_input.p[i:i + 1], + model_input.num_samples, kv_caches) # Retrieve the outputs to CPU. next_token_ids += output_token_ids.cpu().tolist() start_idx = end_idx