mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-05 16:31:25 +08:00
forward context format
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
8332924320
commit
243eac58a4
@ -622,9 +622,8 @@ class EngineArgs:
|
|||||||
parallel_group.add_argument(
|
parallel_group.add_argument(
|
||||||
"--enable-expert-parallel",
|
"--enable-expert-parallel",
|
||||||
**parallel_kwargs["enable_expert_parallel"])
|
**parallel_kwargs["enable_expert_parallel"])
|
||||||
parallel_group.add_argument(
|
parallel_group.add_argument("--enable-microbatching",
|
||||||
"--enable-microbatching",
|
**parallel_kwargs["enable_microbatching"])
|
||||||
**parallel_kwargs["enable_microbatching"])
|
|
||||||
parallel_group.add_argument(
|
parallel_group.add_argument(
|
||||||
"--max-parallel-loading-workers",
|
"--max-parallel-loading-workers",
|
||||||
**parallel_kwargs["max_parallel_loading_workers"])
|
**parallel_kwargs["max_parallel_loading_workers"])
|
||||||
|
|||||||
@ -58,11 +58,11 @@ def get_forward_context() -> ForwardContext:
|
|||||||
"Please use `set_forward_context` to set the forward context.")
|
"Please use `set_forward_context` to set the forward context.")
|
||||||
return _forward_context
|
return _forward_context
|
||||||
|
|
||||||
|
|
||||||
def create_forward_context(attn_metadata: Any,
|
def create_forward_context(attn_metadata: Any,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
virtual_engine: int = 0,
|
virtual_engine: int = 0,
|
||||||
num_tokens: int = 0
|
num_tokens: int = 0):
|
||||||
):
|
|
||||||
dp_metadata: Optional[DPMetadata] = None
|
dp_metadata: Optional[DPMetadata] = None
|
||||||
if vllm_config.parallel_config.data_parallel_size > 1:
|
if vllm_config.parallel_config.data_parallel_size > 1:
|
||||||
dp_size = vllm_config.parallel_config.data_parallel_size
|
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||||
@ -87,12 +87,12 @@ def create_forward_context(attn_metadata: Any,
|
|||||||
dp_metadata = DPMetadata(max_tokens_across_dp_cpu,
|
dp_metadata = DPMetadata(max_tokens_across_dp_cpu,
|
||||||
cu_tokens_across_dp_cpu)
|
cu_tokens_across_dp_cpu)
|
||||||
|
|
||||||
return ForwardContext(
|
return ForwardContext(no_compile_layers=vllm_config.compilation_config.
|
||||||
no_compile_layers=vllm_config.compilation_config.
|
static_forward_context,
|
||||||
static_forward_context,
|
virtual_engine=virtual_engine,
|
||||||
virtual_engine=virtual_engine,
|
attn_metadata=attn_metadata,
|
||||||
attn_metadata=attn_metadata,
|
dp_metadata=dp_metadata)
|
||||||
dp_metadata=dp_metadata)
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def override_forward_context(forward_context: Optional[ForwardContext]):
|
def override_forward_context(forward_context: Optional[ForwardContext]):
|
||||||
@ -123,8 +123,8 @@ def set_forward_context(attn_metadata: Any,
|
|||||||
if need_to_track_batchsize:
|
if need_to_track_batchsize:
|
||||||
forward_start_time = time.perf_counter()
|
forward_start_time = time.perf_counter()
|
||||||
|
|
||||||
forward_context = create_forward_context(
|
forward_context = create_forward_context(attn_metadata, vllm_config,
|
||||||
attn_metadata, vllm_config, virtual_engine, num_tokens)
|
virtual_engine, num_tokens)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with override_forward_context(forward_context):
|
with override_forward_context(forward_context):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user