forward context format

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-06-02 19:16:06 +00:00
parent 8332924320
commit 243eac58a4
2 changed files with 12 additions and 13 deletions

View File

@ -622,9 +622,8 @@ class EngineArgs:
parallel_group.add_argument(
"--enable-expert-parallel",
**parallel_kwargs["enable_expert_parallel"])
parallel_group.add_argument(
"--enable-microbatching",
**parallel_kwargs["enable_microbatching"])
parallel_group.add_argument("--enable-microbatching",
**parallel_kwargs["enable_microbatching"])
parallel_group.add_argument(
"--max-parallel-loading-workers",
**parallel_kwargs["max_parallel_loading_workers"])

View File

@ -58,11 +58,11 @@ def get_forward_context() -> ForwardContext:
"Please use `set_forward_context` to set the forward context.")
return _forward_context
def create_forward_context(attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
num_tokens: int = 0
):
num_tokens: int = 0):
dp_metadata: Optional[DPMetadata] = None
if vllm_config.parallel_config.data_parallel_size > 1:
dp_size = vllm_config.parallel_config.data_parallel_size
@ -87,12 +87,12 @@ def create_forward_context(attn_metadata: Any,
dp_metadata = DPMetadata(max_tokens_across_dp_cpu,
cu_tokens_across_dp_cpu)
return ForwardContext(
no_compile_layers=vllm_config.compilation_config.
static_forward_context,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
dp_metadata=dp_metadata)
return ForwardContext(no_compile_layers=vllm_config.compilation_config.
static_forward_context,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
dp_metadata=dp_metadata)
@contextmanager
def override_forward_context(forward_context: Optional[ForwardContext]):
@ -123,8 +123,8 @@ def set_forward_context(attn_metadata: Any,
if need_to_track_batchsize:
forward_start_time = time.perf_counter()
forward_context = create_forward_context(
attn_metadata, vllm_config, virtual_engine, num_tokens)
forward_context = create_forward_context(attn_metadata, vllm_config,
virtual_engine, num_tokens)
try:
with override_forward_context(forward_context):