From 85d44884583eb4c4f6060b7441f01ef51b5623a5 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 26 Apr 2024 05:31:31 +0000 Subject: [PATCH] yapf --- vllm/worker/tpu_model_runner.py | 3 ++- vllm/worker/tpu_worker.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 89390538a4ab3..67db69c2cdf7a 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -157,7 +157,8 @@ class TPUModelRunner: pad=_PAD_SLOT_ID, dtype=jnp.int32) prompt_lens = jnp.asarray(prompt_lens, dtype=jnp.int32) - return input_tokens, input_positions, slot_mapping, None, None, prompt_lens + return (input_tokens, input_positions, slot_mapping, None, None, + prompt_lens) def _prepare_decode( self, diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 7f5d7efe57880..8c0f2ef7acd6f 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple import jax.numpy as jnp import torch -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.logger import init_logger from vllm.model_executor import set_random_seed