This commit is contained in:
Woosuk Kwon 2024-04-26 05:31:31 +00:00
parent 8d072dbfbd
commit 85d4488458
2 changed files with 3 additions and 2 deletions

View File

@@ -157,7 +157,8 @@ class TPUModelRunner:
pad=_PAD_SLOT_ID,
dtype=jnp.int32)
prompt_lens = jnp.asarray(prompt_lens, dtype=jnp.int32)
return input_tokens, input_positions, slot_mapping, None, None, prompt_lens
return (input_tokens, input_positions, slot_mapping, None, None,
prompt_lens)
def _prepare_decode(
self,

View File

@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple
import jax.numpy as jnp
import torch
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed