mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 12:45:01 +08:00
[Fix] Add model sequence length into model config (#575)
This commit is contained in:
parent
82ad323dee
commit
58a072be15
@ -109,6 +109,26 @@ class ModelConfig:
|
||||
total_num_attention_heads = self.hf_config.num_attention_heads
|
||||
return total_num_attention_heads // parallel_config.tensor_parallel_size
|
||||
|
||||
def get_max_model_len(self) -> int:
|
||||
max_model_len = float("inf")
|
||||
possible_keys = [
|
||||
# OPT
|
||||
"max_position_embeddings",
|
||||
# GPT-2
|
||||
"n_positions",
|
||||
# MPT
|
||||
"max_seq_len",
|
||||
# Others
|
||||
"max_sequence_length",
|
||||
"max_seq_length",
|
||||
"seq_len",
|
||||
]
|
||||
for key in possible_keys:
|
||||
max_len_key = getattr(self.hf_config, key, None)
|
||||
if max_len_key is not None:
|
||||
max_model_len = min(max_model_len, max_len_key)
|
||||
return max_model_len
|
||||
|
||||
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
|
||||
total_num_hidden_layers = self.hf_config.num_hidden_layers
|
||||
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
|
||||
|
||||
@ -155,10 +155,9 @@ class EngineArgs:
|
||||
parallel_config = ParallelConfig(self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size,
|
||||
self.worker_use_ray)
|
||||
max_model_len = getattr(model_config.hf_config,
|
||||
'max_position_embeddings', float('inf'))
|
||||
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
|
||||
self.max_num_seqs, max_model_len)
|
||||
self.max_num_seqs,
|
||||
model_config.get_max_model_len())
|
||||
return model_config, cache_config, parallel_config, scheduler_config
|
||||
|
||||
|
||||
|
||||
@ -107,25 +107,14 @@ async def get_gen_prompt(request) -> str:
|
||||
return prompt
|
||||
|
||||
|
||||
async def check_length(request, prompt, model_config):
|
||||
if hasattr(model_config.hf_config, "max_sequence_length"):
|
||||
context_len = model_config.hf_config.max_sequence_length
|
||||
elif hasattr(model_config.hf_config, "seq_length"):
|
||||
context_len = model_config.hf_config.seq_length
|
||||
elif hasattr(model_config.hf_config, "max_position_embeddings"):
|
||||
context_len = model_config.hf_config.max_position_embeddings
|
||||
elif hasattr(model_config.hf_config, "seq_length"):
|
||||
context_len = model_config.hf_config.seq_length
|
||||
else:
|
||||
context_len = 2048
|
||||
|
||||
async def check_length(request, prompt):
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
token_num = len(input_ids)
|
||||
|
||||
if token_num + request.max_tokens > context_len:
|
||||
if token_num + request.max_tokens > max_model_len:
|
||||
return create_error_response(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
f"This model's maximum context length is {context_len} tokens. "
|
||||
f"This model's maximum context length is {max_model_len} tokens. "
|
||||
f"However, you requested {request.max_tokens + token_num} tokens "
|
||||
f"({token_num} in the messages, "
|
||||
f"{request.max_tokens} in the completion). "
|
||||
@ -194,7 +183,7 @@ async def create_chat_completion(raw_request: Request):
|
||||
"logit_bias is not currently supported")
|
||||
|
||||
prompt = await get_gen_prompt(request)
|
||||
error_check_ret = await check_length(request, prompt, engine_model_config)
|
||||
error_check_ret = await check_length(request, prompt)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
@ -591,6 +580,7 @@ if __name__ == "__main__":
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
engine_model_config = asyncio.run(engine.get_model_config())
|
||||
max_model_len = engine_model_config.get_max_model_len()
|
||||
|
||||
# A separate tokenizer to map token IDs to strings.
|
||||
tokenizer = get_tokenizer(engine_args.tokenizer,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user