[Fix] Add model sequence length into model config (#575)
parent 82ad323dee
commit 58a072be15
@@ -109,6 +109,26 @@ class ModelConfig:
         total_num_attention_heads = self.hf_config.num_attention_heads
         return total_num_attention_heads // parallel_config.tensor_parallel_size
 
+    def get_max_model_len(self) -> int:
+        max_model_len = float("inf")
+        possible_keys = [
+            # OPT
+            "max_position_embeddings",
+            # GPT-2
+            "n_positions",
+            # MPT
+            "max_seq_len",
+            # Others
+            "max_sequence_length",
+            "max_seq_length",
+            "seq_len",
+        ]
+        for key in possible_keys:
+            max_len_key = getattr(self.hf_config, key, None)
+            if max_len_key is not None:
+                max_model_len = min(max_model_len, max_len_key)
+        return max_model_len
+
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_config.num_hidden_layers
         return total_num_hidden_layers // parallel_config.pipeline_parallel_size
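The new helper does not depend on any other vLLM state, so its behaviour is easy to reproduce in isolation. A minimal sketch of the same lookup against a raw HuggingFace config (resolve_max_model_len and the example model name are illustrative, not part of this change):

    from transformers import AutoConfig

    def resolve_max_model_len(model_name: str) -> float:
        # Same key scan as the new ModelConfig.get_max_model_len(), applied to a
        # raw HuggingFace config instead of vLLM's ModelConfig wrapper.
        hf_config = AutoConfig.from_pretrained(model_name)
        possible_keys = [
            "max_position_embeddings",  # OPT
            "n_positions",              # GPT-2
            "max_seq_len",              # MPT
            "max_sequence_length",
            "max_seq_length",
            "seq_len",
        ]
        max_model_len = float("inf")
        for key in possible_keys:
            value = getattr(hf_config, key, None)
            if value is not None:
                max_model_len = min(max_model_len, value)
        return max_model_len

    # resolve_max_model_len("facebook/opt-125m") should give 2048
    # (max_position_embeddings in the OPT config).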
@@ -155,10 +155,9 @@ class EngineArgs:
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray)
-        max_model_len = getattr(model_config.hf_config,
-                                'max_position_embeddings', float('inf'))
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs, max_model_len)
+                                           self.max_num_seqs,
+                                           model_config.get_max_model_len())
         return model_config, cache_config, parallel_config, scheduler_config
 
 
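This EngineArgs hunk is where the fix actually lands: previously the scheduler's maximum length came from a single getattr on max_position_embeddings, which silently falls back to infinity for architectures that name the field differently. A small sketch of the difference, using a stand-in config object rather than a real model:

    from types import SimpleNamespace

    # Hypothetical MPT-style config: exposes max_seq_len but not max_position_embeddings.
    hf_config = SimpleNamespace(max_seq_len=2048)

    # Old behaviour (the getattr removed above): only max_position_embeddings was
    # consulted, so the scheduler saw no length limit at all.
    old_len = getattr(hf_config, "max_position_embeddings", float("inf"))
    assert old_len == float("inf")

    # New behaviour: ModelConfig.get_max_model_len() scans the full key list,
    # finds max_seq_len, and SchedulerConfig receives 2048 instead.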
@@ -107,25 +107,14 @@ async def get_gen_prompt(request) -> str:
     return prompt
 
 
-async def check_length(request, prompt, model_config):
-    if hasattr(model_config.hf_config, "max_sequence_length"):
-        context_len = model_config.hf_config.max_sequence_length
-    elif hasattr(model_config.hf_config, "seq_length"):
-        context_len = model_config.hf_config.seq_length
-    elif hasattr(model_config.hf_config, "max_position_embeddings"):
-        context_len = model_config.hf_config.max_position_embeddings
-    elif hasattr(model_config.hf_config, "seq_length"):
-        context_len = model_config.hf_config.seq_length
-    else:
-        context_len = 2048
-
+async def check_length(request, prompt):
     input_ids = tokenizer(prompt).input_ids
     token_num = len(input_ids)
 
-    if token_num + request.max_tokens > context_len:
+    if token_num + request.max_tokens > max_model_len:
         return create_error_response(
             HTTPStatus.BAD_REQUEST,
-            f"This model's maximum context length is {context_len} tokens. "
+            f"This model's maximum context length is {max_model_len} tokens. "
             f"However, you requested {request.max_tokens + token_num} tokens "
             f"({token_num} in the messages, "
             f"{request.max_tokens} in the completion). "
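check_length no longer rebuilds the context length from the HF config on every request (and so no longer needs the model_config argument); it reads the module-level max_model_len computed once at startup in the __main__ hunk below. A rough sketch of the validation it performs (would_exceed_context and the 2048 value are illustrative):

    from http import HTTPStatus

    # In the server, max_model_len is set at startup from
    # engine_model_config.get_max_model_len(); 2048 here is just an assumed example.
    max_model_len = 2048

    def would_exceed_context(prompt_token_ids: list[int], max_tokens: int) -> bool:
        # Mirrors the new check: prompt tokens plus requested completion tokens
        # must fit inside the model's context window.
        return len(prompt_token_ids) + max_tokens > max_model_len

    # would_exceed_context(list(range(2000)), 100) -> True  -> HTTPStatus.BAD_REQUEST
    # would_exceed_context(list(range(2000)), 48)  -> False -> request proceeds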
@@ -194,7 +183,7 @@ async def create_chat_completion(raw_request: Request):
             "logit_bias is not currently supported")
 
     prompt = await get_gen_prompt(request)
-    error_check_ret = await check_length(request, prompt, engine_model_config)
+    error_check_ret = await check_length(request, prompt)
     if error_check_ret is not None:
         return error_check_ret
 
@@ -591,6 +580,7 @@ if __name__ == "__main__":
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     engine_model_config = asyncio.run(engine.get_model_config())
+    max_model_len = engine_model_config.get_max_model_len()
 
     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(engine_args.tokenizer,
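With max_model_len initialised at startup, an over-long request to the chat completions endpoint should now be rejected up front rather than failing inside the engine. A hedged end-to-end illustration (URL, model name, and response shape are assumptions about the running server, not shown in this diff):

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",  # assumed default host/port
        json={
            "model": "facebook/opt-125m",              # placeholder served model
            "messages": [{"role": "user", "content": "Hi"}],
            "max_tokens": 1_000_000,                   # deliberately beyond any context window
        },
    )
    print(resp.status_code)  # expected: 400 (HTTPStatus.BAD_REQUEST)
    print(resp.json())       # includes "This model's maximum context length is ... tokens."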