mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 01:27:02 +08:00
forward context format
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
8332924320
commit
243eac58a4
@@ -622,9 +622,8 @@ class EngineArgs:
|
||||
parallel_group.add_argument(
|
||||
"--enable-expert-parallel",
|
||||
**parallel_kwargs["enable_expert_parallel"])
|
||||
parallel_group.add_argument(
|
||||
"--enable-microbatching",
|
||||
**parallel_kwargs["enable_microbatching"])
|
||||
parallel_group.add_argument("--enable-microbatching",
|
||||
**parallel_kwargs["enable_microbatching"])
|
||||
parallel_group.add_argument(
|
||||
"--max-parallel-loading-workers",
|
||||
**parallel_kwargs["max_parallel_loading_workers"])
|
||||
|
||||
@@ -58,11 +58,11 @@ def get_forward_context() -> ForwardContext:
|
||||
"Please use `set_forward_context` to set the forward context.")
|
||||
return _forward_context
|
||||
|
||||
|
||||
def create_forward_context(attn_metadata: Any,
|
||||
vllm_config: VllmConfig,
|
||||
virtual_engine: int = 0,
|
||||
num_tokens: int = 0
|
||||
):
|
||||
num_tokens: int = 0):
|
||||
dp_metadata: Optional[DPMetadata] = None
|
||||
if vllm_config.parallel_config.data_parallel_size > 1:
|
||||
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
@@ -87,12 +87,12 @@ def create_forward_context(attn_metadata: Any,
|
||||
dp_metadata = DPMetadata(max_tokens_across_dp_cpu,
|
||||
cu_tokens_across_dp_cpu)
|
||||
|
||||
return ForwardContext(
|
||||
no_compile_layers=vllm_config.compilation_config.
|
||||
static_forward_context,
|
||||
virtual_engine=virtual_engine,
|
||||
attn_metadata=attn_metadata,
|
||||
dp_metadata=dp_metadata)
|
||||
return ForwardContext(no_compile_layers=vllm_config.compilation_config.
|
||||
static_forward_context,
|
||||
virtual_engine=virtual_engine,
|
||||
attn_metadata=attn_metadata,
|
||||
dp_metadata=dp_metadata)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def override_forward_context(forward_context: Optional[ForwardContext]):
|
||||
@@ -123,8 +123,8 @@ def set_forward_context(attn_metadata: Any,
|
||||
if need_to_track_batchsize:
|
||||
forward_start_time = time.perf_counter()
|
||||
|
||||
forward_context = create_forward_context(
|
||||
attn_metadata, vllm_config, virtual_engine, num_tokens)
|
||||
forward_context = create_forward_context(attn_metadata, vllm_config,
|
||||
virtual_engine, num_tokens)
|
||||
|
||||
try:
|
||||
with override_forward_context(forward_context):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user