mirror of https://git.datalinker.icu/vllm-project/vllm.git
fix max seq len (#489)

commit b4b195b360
parent 20b0d88d16
@@ -204,10 +204,10 @@ class SchedulerConfig:
     """

     def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
-                 max_seq_len: int) -> None:
+                 max_model_len: int) -> None:
         self.max_num_batched_tokens = max_num_batched_tokens
         self.max_num_seqs = max_num_seqs
-        self.max_seq_len = max_seq_len
+        self.max_model_len = max_model_len


 _STR_DTYPE_TO_TORCH_DTYPE = {
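For reference, the renamed parameter is simply the third argument of SchedulerConfig; a minimal usage sketch (the numeric values are hypothetical and not taken from this commit):

# Hypothetical values, for illustration only.
scheduler_config = SchedulerConfig(max_num_batched_tokens=2560,
                                   max_num_seqs=256,
                                   max_model_len=2048)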
@@ -190,7 +190,9 @@ class Scheduler:
                 break

             num_prompt_tokens = seq_group.get_seqs()[0].get_len()
-            if num_prompt_tokens > self.scheduler_config.max_seq_len:
+            if num_prompt_tokens > min(
+                    self.scheduler_config.max_model_len,
+                    self.scheduler_config.max_num_batched_tokens):
                 logger.warning(
                     f"Input prompt ({num_prompt_tokens} tokens) is too long"
                     " and exceeds limit of "
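A minimal standalone sketch of the new admission rule, assuming nothing beyond the min() comparison in the hunk above (the helper name prompt_fits and the numbers are hypothetical):

def prompt_fits(num_prompt_tokens: int, max_model_len: int,
                max_num_batched_tokens: int) -> bool:
    # A prompt is schedulable only if it fits both the model's context
    # window and the per-iteration token budget.
    return num_prompt_tokens <= min(max_model_len, max_num_batched_tokens)

# A 3000-token prompt is rejected when the context window is 2048,
# even though the 2560-token batch budget alone would allow it.
assert not prompt_fits(3000, max_model_len=2048, max_num_batched_tokens=2560)
assert prompt_fits(1500, max_model_len=2048, max_num_batched_tokens=2560)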
@@ -155,11 +155,10 @@ class EngineArgs:
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray)
-        model_max_len = getattr(model_config.hf_config,
+        max_model_len = getattr(model_config.hf_config,
                                 'max_position_embeddings', float('inf'))
-        max_seq_len = min(self.max_num_batched_tokens, model_max_len)
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs, max_seq_len)
+                                           self.max_num_seqs, max_model_len)
         return model_config, cache_config, parallel_config, scheduler_config


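A hedged sketch of how max_model_len is now resolved: it is read straight from the model's Hugging Face config via getattr, falling back to float('inf') when max_position_embeddings is absent (the two stub config classes below are hypothetical, only there to make the example self-contained):

class _ConfigWithLimit:
    # Stands in for a hf_config that declares a context window.
    max_position_embeddings = 2048

class _ConfigWithoutLimit:
    # Stands in for a hf_config without the attribute.
    pass

# Mirrors the getattr call in the hunk above.
print(getattr(_ConfigWithLimit(), 'max_position_embeddings', float('inf')))     # 2048
print(getattr(_ConfigWithoutLimit(), 'max_position_embeddings', float('inf')))  # inf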
@@ -300,8 +300,7 @@ class LLMEngine:
                 continue

                 # Check if the sequence has reached max_seq_len.
-                if (seq.get_len() >
-                        self.scheduler.scheduler_config.max_seq_len):
+                if seq.get_len() > self.scheduler_config.max_model_len:
                     self.scheduler.free_seq(
                         seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                     continue
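A minimal sketch of this length-capping step, assuming a toy Sequence with get_len() (the stub classes and maybe_cap helper are hypothetical; only the comparison against max_model_len and the FINISHED_LENGTH_CAPPED status mirror the hunk above):

from enum import Enum, auto

class SequenceStatus(Enum):
    RUNNING = auto()
    FINISHED_LENGTH_CAPPED = auto()

class Sequence:
    def __init__(self, token_ids: list) -> None:
        self.token_ids = token_ids
        self.status = SequenceStatus.RUNNING

    def get_len(self) -> int:
        return len(self.token_ids)

def maybe_cap(seq: Sequence, max_model_len: int) -> None:
    # A sequence that grows past the model's context length is finished
    # with FINISHED_LENGTH_CAPPED instead of being scheduled again.
    if seq.get_len() > max_model_len:
        seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED

seq = Sequence(token_ids=list(range(2049)))
maybe_cap(seq, max_model_len=2048)
assert seq.status is SequenceStatus.FINISHED_LENGTH_CAPPED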