diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e1a9eb179e7ae..1c6df960ebe36 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -622,9 +622,8 @@ class EngineArgs: parallel_group.add_argument( "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]) - parallel_group.add_argument( - "--enable-microbatching", - **parallel_kwargs["enable_microbatching"]) + parallel_group.add_argument("--enable-microbatching", + **parallel_kwargs["enable_microbatching"]) parallel_group.add_argument( "--max-parallel-loading-workers", **parallel_kwargs["max_parallel_loading_workers"]) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index fce9245d15bbd..bb43302c323b3 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -58,11 +58,11 @@ def get_forward_context() -> ForwardContext: "Please use `set_forward_context` to set the forward context.") return _forward_context + def create_forward_context(attn_metadata: Any, vllm_config: VllmConfig, virtual_engine: int = 0, - num_tokens: int = 0 -): + num_tokens: int = 0): dp_metadata: Optional[DPMetadata] = None if vllm_config.parallel_config.data_parallel_size > 1: dp_size = vllm_config.parallel_config.data_parallel_size @@ -87,12 +87,12 @@ def create_forward_context(attn_metadata: Any, dp_metadata = DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu) - return ForwardContext( - no_compile_layers=vllm_config.compilation_config. - static_forward_context, - virtual_engine=virtual_engine, - attn_metadata=attn_metadata, - dp_metadata=dp_metadata) + return ForwardContext(no_compile_layers=vllm_config.compilation_config. + static_forward_context, + virtual_engine=virtual_engine, + attn_metadata=attn_metadata, + dp_metadata=dp_metadata) + @contextmanager def override_forward_context(forward_context: Optional[ForwardContext]): @@ -123,8 +123,8 @@ def set_forward_context(attn_metadata: Any, if need_to_track_batchsize: forward_start_time = time.perf_counter() - forward_context = create_forward_context( - attn_metadata, vllm_config, virtual_engine, num_tokens) + forward_context = create_forward_context(attn_metadata, vllm_config, + virtual_engine, num_tokens) try: with override_forward_context(forward_context):