diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 89390538a4ab3..67db69c2cdf7a 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -157,7 +157,8 @@ class TPUModelRunner: pad=_PAD_SLOT_ID, dtype=jnp.int32) prompt_lens = jnp.asarray(prompt_lens, dtype=jnp.int32) - return input_tokens, input_positions, slot_mapping, None, None, prompt_lens + return (input_tokens, input_positions, slot_mapping, None, None, + prompt_lens) def _prepare_decode( self, diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 7f5d7efe57880..8c0f2ef7acd6f 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple import jax.numpy as jnp import torch -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.logger import init_logger from vllm.model_executor import set_random_seed