diff --git a/vllm/config.py b/vllm/config.py
index 9fed21f5c61e6..a3430f23a4583 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -204,10 +204,10 @@ class SchedulerConfig:
     """

     def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
-                 max_seq_len: int) -> None:
+                 max_model_len: int) -> None:
         self.max_num_batched_tokens = max_num_batched_tokens
         self.max_num_seqs = max_num_seqs
-        self.max_seq_len = max_seq_len
+        self.max_model_len = max_model_len


 _STR_DTYPE_TO_TORCH_DTYPE = {
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index d6379f93c5a58..2b37a3af26295 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -190,7 +190,9 @@ class Scheduler:
                     break

                 num_prompt_tokens = seq_group.get_seqs()[0].get_len()
-                if num_prompt_tokens > self.scheduler_config.max_seq_len:
+                if num_prompt_tokens > min(
+                        self.scheduler_config.max_model_len,
+                        self.scheduler_config.max_num_batched_tokens):
                     logger.warning(
                         f"Input prompt ({num_prompt_tokens} tokens) is too long"
                         " and exceeds limit of "
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 463de6f70c91f..ce1f0f4ece877 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -155,11 +155,10 @@ class EngineArgs:
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray)
-        model_max_len = getattr(model_config.hf_config,
+        max_model_len = getattr(model_config.hf_config,
                                 'max_position_embeddings', float('inf'))
-        max_seq_len = min(self.max_num_batched_tokens, model_max_len)
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs, max_seq_len)
+                                           self.max_num_seqs, max_model_len)
         return model_config, cache_config, parallel_config, scheduler_config


diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 1fdb2d04d53a8..ea7e79c7e6d1e 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -300,8 +300,7 @@ class LLMEngine:
                 continue

             # Check if the sequence has reached max_seq_len.
-            if (seq.get_len() >
-                    self.scheduler.scheduler_config.max_seq_len):
+            if seq.get_len() > self.scheduler_config.max_model_len:
                 self.scheduler.free_seq(
                     seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                 continue