mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-24 06:59:16 +08:00
[Chore] remove duplicate code
This commit is contained in:
parent
f74bb82909
commit
26ddfa299c
@ -642,25 +642,6 @@ class GPUModelRunner(
|
||||
with_stack=False,
|
||||
)
|
||||
|
||||
profile_dir = (
|
||||
"./profiler_logs/attn"
|
||||
if self.afd_config is not None and self.afd_config.afd_role == "attention"
|
||||
else "./profiler_logs/normal"
|
||||
)
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
schedule=torch.profiler.schedule(
|
||||
wait=6000 + 4000, warmup=1, active=30, repeat=1
|
||||
),
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_dir),
|
||||
record_shapes=True,
|
||||
profile_memory=False,
|
||||
with_stack=False,
|
||||
)
|
||||
|
||||
def reset_mm_cache(self) -> None:
    """Clear the cached multimodal data when a multimodal budget exists."""
    # Guard clause: nothing to reset without a configured budget.
    if not self.mm_budget:
        return
    self.mm_budget.reset_cache()
|
||||
@ -2988,38 +2969,6 @@ class GPUModelRunner(
|
||||
)
|
||||
return afd_metadata
|
||||
|
||||
def _build_afd_metadata(
    self, ubatch_slices: UBatchSlices | None, num_tokens_unpadded: int
):
    """Build AFD metadata describing how tokens/requests split into stages.

    Args:
        ubatch_slices: Optional micro-batch slices. When more than one is
            present, each slice becomes one AFD stage.
        num_tokens_unpadded: Total unpadded token count; used as the single
            stage length when micro-batching is not active.

    Returns:
        An ``AFDMetadata`` instance when ``self.afd_config`` is set,
        otherwise ``None``.
    """
    if not self.afd_config:
        return None
    if ubatch_slices and len(ubatch_slices) > 1:
        # One AFD stage per micro-batch: derive per-stage start offsets and
        # lengths from the actual token/request slices.
        afd_tokens_start_loc = [ub.token_slice.start for ub in ubatch_slices]
        afd_reqs_start_loc = [ub.request_slice.start for ub in ubatch_slices]
        afd_tokens_lens = [ub.num_tokens for ub in ubatch_slices]
        # Lazy %-style args so formatting only happens if INFO is enabled.
        logger.info(
            "afd_tokens_start_loc: %s afd_reqs_start_loc: %s ubatch_slices: %s",
            afd_tokens_start_loc,
            afd_reqs_start_loc,
            ubatch_slices,
        )
    else:
        # Single stage covering the whole (unpadded) batch.
        afd_tokens_start_loc = [0]
        afd_reqs_start_loc = [0]
        afd_tokens_lens = [num_tokens_unpadded]
    return AFDMetadata(
        afd_tokens_start_loc=afd_tokens_start_loc,
        afd_reqs_start_loc=afd_reqs_start_loc,
        afd_stage_idx=0,
        afd_connector=self.afd_connector,
        afd_tokens_lens=afd_tokens_lens,
        num_of_stages=len(ubatch_slices) if ubatch_slices else 1,
    )
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
@ -5573,11 +5522,6 @@ class GPUModelRunner(
|
||||
if hasattr(self, "afd_connector") and self.afd_connector:
|
||||
self.afd_connector.init_afd_connector()
|
||||
|
||||
def initialize_afd_connector(self) -> None:
    """Initialize the AFD connector when one has been attached to this runner."""
    # getattr with a default folds the hasattr check and the truthiness
    # check into a single lookup.
    connector = getattr(self, "afd_connector", None)
    if connector:
        connector.init_afd_connector()
|
||||
|
||||
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
|
||||
"""
|
||||
Add encoder-only layers to the KV cache config.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user