Provide default max model length (#1224)

Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-16 12:35:01 +08:00)
Commit: f936657eb6
Parent: 6f88f762bf
@@ -164,9 +164,6 @@ class ModelConfig:
         total_num_attention_heads = self.hf_config.num_attention_heads
         return total_num_attention_heads // parallel_config.tensor_parallel_size
 
-    def get_max_model_len(self) -> int:
-        return self.max_model_len
-
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_config.num_hidden_layers
         return total_num_hidden_layers // parallel_config.pipeline_parallel_size
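The hunk above drops the trivial `get_max_model_len()` accessor; callers now read `max_model_len` as a plain attribute. A minimal sketch of the change in calling convention, using an illustrative stand-in class rather than the real `ModelConfig` (which sets the attribute during construction):

from dataclasses import dataclass

# Illustrative stand-in for the relevant slice of ModelConfig; not the real class.
@dataclass
class ModelConfigSketch:
    max_model_len: int

cfg = ModelConfigSketch(max_model_len=2048)
print(cfg.max_model_len)   # after this change: plain attribute access
# cfg.get_max_model_len()  # before this change: accessor method (now removed)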
@@ -378,10 +375,17 @@ def _get_and_verify_max_len(
         if max_len_key is not None:
             derived_max_model_len = min(derived_max_model_len, max_len_key)
     if derived_max_model_len == float("inf"):
-        raise ValueError(
-            "The model's config.json must contain one of the following keys "
-            "to determine the original maximum length of the model: "
-            f"{possible_keys}")
+        if max_model_len is not None:
+            # If max_model_len is specified, we use it.
+            return max_model_len
+
+        default_max_len = 2048
+        logger.warning(
+            "The model's config.json does not contain any of the following "
+            "keys to determine the original maximum length of the model: "
+            f"{possible_keys}. Assuming the model's maximum length is "
+            f"{default_max_len}.")
+        derived_max_model_len = default_max_len
 
     rope_scaling = getattr(hf_config, "rope_scaling", None)
     if rope_scaling is not None:
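The new branch replaces the hard error with a fallback: an explicit user-provided `max_model_len` wins, otherwise the engine warns and assumes 2048 tokens. Below is a self-contained sketch of that control flow; the dict standing in for the Hugging Face config, the `print` standing in for `logger.warning`, and the exact key list are assumptions, and the rope-scaling adjustment done by the real helper is omitted:

from typing import Optional

def derive_max_model_len_sketch(hf_config: dict,
                                max_model_len: Optional[int]) -> int:
    # Representative (not necessarily exhaustive) config.json keys probed
    # to find the model's native maximum length.
    possible_keys = [
        "max_position_embeddings", "n_positions", "max_seq_len",
        "max_sequence_length", "seq_len",
    ]
    derived_max_model_len = float("inf")
    for key in possible_keys:
        value = hf_config.get(key)
        if value is not None:
            derived_max_model_len = min(derived_max_model_len, value)
    if derived_max_model_len == float("inf"):
        if max_model_len is not None:
            # An explicitly specified max_model_len takes precedence.
            return max_model_len
        default_max_len = 2048
        print(f"config.json has none of {possible_keys}; assuming the "
              f"model's maximum length is {default_max_len}.")
        derived_max_model_len = default_max_len
    return int(derived_max_model_len)

print(derive_max_model_len_sketch({"max_position_embeddings": 4096}, None))  # 4096
print(derive_max_model_len_sketch({}, None))   # warns, then 2048
print(derive_max_model_len_sketch({}, 8192))   # 8192, no warning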
@@ -184,7 +184,7 @@ class EngineArgs:
                                          self.worker_use_ray)
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                            self.max_num_seqs,
-                                           model_config.get_max_model_len())
+                                           model_config.max_model_len)
         return model_config, cache_config, parallel_config, scheduler_config
 
 
@@ -77,6 +77,7 @@ class LLMEngine:
             f"revision={model_config.revision}, "
             f"trust_remote_code={model_config.trust_remote_code}, "
             f"dtype={model_config.dtype}, "
+            f"max_seq_len={model_config.max_model_len}, "
             f"download_dir={model_config.download_dir!r}, "
             f"load_format={model_config.load_format}, "
             f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
@@ -615,7 +615,7 @@ if __name__ == "__main__":
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     engine_model_config = asyncio.run(engine.get_model_config())
-    max_model_len = engine_model_config.get_max_model_len()
+    max_model_len = engine_model_config.max_model_len
 
     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(engine_args.tokenizer,
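With the attribute in hand, the OpenAI-compatible server can bound a request's prompt plus generation budget by the model's maximum length. A hedged sketch of that kind of check follows; the function name, signature, and error wording are illustrative, not the server's actual implementation:

from typing import List, Optional

def check_request_length_sketch(prompt_token_ids: List[int],
                                max_tokens: int,
                                max_model_len: int) -> Optional[str]:
    # Reject requests whose prompt plus requested completion cannot fit
    # within the model's maximum context length.
    requested = len(prompt_token_ids) + max_tokens
    if requested > max_model_len:
        return (f"This model's maximum context length is {max_model_len} "
                f"tokens, but {requested} tokens were requested.")
    return None  # request fits

print(check_request_length_sketch(list(range(2000)), 256, 2048))  # error string
print(check_request_length_sketch(list(range(1024)), 256, 2048))  # None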