[Misc] Misc code cleanup/simplification (#23304)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-08-21 10:22:55 -07:00 committed by GitHub
parent 10f535c086
commit 603fbbbce0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 51 additions and 57 deletions

View File

@ -91,7 +91,7 @@ class Sampler(nn.Module):
logits = self.apply_bad_words(logits, sampling_metadata) logits = self.apply_bad_words(logits, sampling_metadata)
# Apply logits processors which can impact greedy sampling # Apply logits processors which can impact greedy sampling
for processor in (sampling_metadata.logitsprocs.non_argmax_invariant): for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
logits = processor.apply(logits) logits = processor.apply(logits)
# Apply penalties (e.g., min_tokens, freq_penalties). # Apply penalties (e.g., min_tokens, freq_penalties).

View File

@ -442,10 +442,11 @@ class InputBatch:
# LoRA # LoRA
lora_id = self.request_lora_mapping[req_index] lora_id = self.request_lora_mapping[req_index]
if lora_id != 0: if lora_id != 0:
self.lora_id_to_request_ids[lora_id].discard(req_id) lora_req_ids = self.lora_id_to_request_ids[lora_id]
if len(self.lora_id_to_request_ids[lora_id]) == 0: lora_req_ids.discard(req_id)
self.lora_id_to_request_ids.pop(lora_id) if not lora_req_ids:
self.lora_id_to_lora_request.pop(lora_id) del self.lora_id_to_request_ids[lora_id]
del self.lora_id_to_lora_request[lora_id]
self.request_lora_mapping[req_index] = 0 self.request_lora_mapping[req_index] = 0
self.has_allowed_token_ids.discard(req_id) self.has_allowed_token_ids.discard(req_id)

View File

@ -358,6 +358,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if num_pooling_reqs == 0: if num_pooling_reqs == 0:
return model_kwargs return model_kwargs
# This does nontrivial work.
pooling_params = self.input_batch.pooling_metadata.pooling_params pooling_params = self.input_batch.pooling_metadata.pooling_params
assert num_pooling_reqs == num_reqs assert num_pooling_reqs == num_reqs
@ -465,7 +466,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for req_id in unscheduled_req_ids: for req_id in unscheduled_req_ids:
self.input_batch.remove_request(req_id) self.input_batch.remove_request(req_id)
req_ids_to_add: list[str] = [] reqs_to_add: list[CachedRequestState] = []
# Add new requests to the cached states. # Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs: for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id req_id = new_req_data.req_id
@ -480,14 +481,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
generator = None generator = None
if pooling_params: if pooling_params:
assert (task := pooling_params.task) is not None, ( task = pooling_params.task
"You did not set `task` in the API") assert task is not None, "You did not set `task` in the API"
model = cast(VllmModelForPooling, self.get_model()) model = cast(VllmModelForPooling, self.get_model())
to_update = model.pooler.get_pooling_updates(task) to_update = model.pooler.get_pooling_updates(task)
to_update.apply(pooling_params) to_update.apply(pooling_params)
self.requests[req_id] = CachedRequestState( req_state = CachedRequestState(
req_id=req_id, req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids, prompt_token_ids=new_req_data.prompt_token_ids,
mm_kwargs=new_req_data.mm_kwargs, mm_kwargs=new_req_data.mm_kwargs,
@ -501,6 +502,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
lora_request=new_req_data.lora_request, lora_request=new_req_data.lora_request,
) )
self.requests[req_id] = req_state
# Only relevant for models using M-RoPE (e.g., Qwen2-VL) # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
if self.uses_mrope: if self.uses_mrope:
image_grid_thw = [] image_grid_thw = []
@ -508,29 +511,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
second_per_grid_ts = [] second_per_grid_ts = []
audio_feature_lengths = [] audio_feature_lengths = []
use_audio_in_video = False use_audio_in_video = False
for mm_item in self.requests[req_id].mm_kwargs: for mm_item in req_state.mm_kwargs:
mm_input = mm_item.get_data() mm_input = mm_item.get_data()
if mm_input.get("image_grid_thw") is not None: if (t := mm_input.get("image_grid_thw")) is not None:
image_grid_thw.append( image_grid_thw.append(t.tolist())
mm_input["image_grid_thw"].tolist()) if (t := mm_input.get("video_grid_thw")) is not None:
if mm_input.get("video_grid_thw") is not None: video_grid_thw.append(t.tolist())
video_grid_thw.append( if (t := mm_input.get("second_per_grid_ts")) is not None:
mm_input["video_grid_thw"].tolist()) second_per_grid_ts.append(t)
if mm_input.get("second_per_grid_ts") is not None: if (t :=
second_per_grid_ts.append( mm_input.get("audio_feature_lengths")) is not None:
mm_input["second_per_grid_ts"]) audio_feature_lengths.append(t)
if mm_input.get("audio_feature_lengths") is not None:
audio_feature_lengths.append(
mm_input["audio_feature_lengths"])
if mm_input.get("use_audio_in_video") is True: if mm_input.get("use_audio_in_video") is True:
use_audio_in_video = True use_audio_in_video = True
hf_config = self.model_config.hf_config hf_config = self.model_config.hf_config
self.requests[req_id].mrope_positions, \ req_state.mrope_positions, req_state.mrope_position_delta = \
self.requests[req_id].mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor( MRotaryEmbedding.get_input_positions_tensor(
self.requests[req_id].prompt_token_ids, req_state.prompt_token_ids,
hf_config=hf_config, hf_config=hf_config,
image_grid_thw=image_grid_thw, image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw, video_grid_thw=video_grid_thw,
@ -539,7 +538,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
use_audio_in_video=use_audio_in_video, use_audio_in_video=use_audio_in_video,
) )
req_ids_to_add.append(req_id) reqs_to_add.append(req_state)
# Update the states of the running/resumed requests. # Update the states of the running/resumed requests.
is_last_rank = get_pp_group().is_last_rank is_last_rank = get_pp_group().is_last_rank
@ -587,7 +586,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# The request is not in the persistent batch. # The request is not in the persistent batch.
# The request was either preempted and resumed later, or was not # The request was either preempted and resumed later, or was not
# scheduled in the previous step and needs to be added again. # scheduled in the previous step and needs to be added again.
req_ids_to_add.append(req_id) reqs_to_add.append(req_state)
continue continue
# Update the persistent batch. # Update the persistent batch.
@ -624,9 +623,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Add the new or resumed requests to the persistent batch. # Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first. # The smaller empty indices are filled first.
for req_id in req_ids_to_add: for request in reqs_to_add:
req_state = self.requests[req_id] self.input_batch.add_request(request)
self.input_batch.add_request(req_state)
# Condense the batched states if there are gaps left by removed requests # Condense the batched states if there are gaps left by removed requests
self.input_batch.condense() self.input_batch.condense()
@ -639,38 +637,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> BatchedTensorInputs: ) -> BatchedTensorInputs:
if self.is_multimodal_raw_input_supported: # noqa: SIM102 if not self.is_multimodal_raw_input_supported or not scheduler_output: # noqa: SIM102
if scheduler_output: return {}
mm_kwargs = list[MultiModalKwargsItem]()
for req in scheduler_output.scheduled_new_reqs:
req_mm_kwargs = req.mm_kwargs
if not isinstance(req_mm_kwargs, list):
req_mm_kwargs = list(req_mm_kwargs)
mm_kwargs.extend(req_mm_kwargs)
# Input all modalities at once mm_kwargs = list[MultiModalKwargsItem]()
mm_kwargs_combined: BatchedTensorInputs = {} for req in scheduler_output.scheduled_new_reqs:
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( mm_kwargs.extend(req.mm_kwargs)
mm_kwargs,
device=self.device,
pin_memory=self.pin_memory,
):
mm_kwargs_combined.update(mm_kwargs_group)
return mm_kwargs_combined # Input all modalities at once
mm_kwargs_combined: BatchedTensorInputs = {}
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
pin_memory=self.pin_memory,
):
mm_kwargs_combined.update(mm_kwargs_group)
return {} return mm_kwargs_combined
def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
if self.is_multimodal_raw_input_supported: if not self.is_multimodal_raw_input_supported:
mm_budget = self.mm_budget return {}
assert mm_budget is not None mm_budget = self.mm_budget
assert mm_budget is not None
dummy_modality = mm_budget.get_modality_with_max_tokens() dummy_modality = mm_budget.get_modality_with_max_tokens()
return self._get_mm_dummy_batch(dummy_modality, num_seqs)
return self._get_mm_dummy_batch(dummy_modality, num_seqs)
return {}
def _get_cumsum_and_arange( def _get_cumsum_and_arange(
self, self,
@ -1612,6 +1604,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
), self.maybe_get_kv_connector_output( ), self.maybe_get_kv_connector_output(
scheduler_output) as kv_connector_output: scheduler_output) as kv_connector_output:
model_output = self.model( model_output = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,

View File

@ -544,7 +544,7 @@ class WorkerWrapperBase:
Arguments are passed to the worker class constructor. Arguments are passed to the worker class constructor.
""" """
kwargs = all_kwargs[self.rpc_rank] kwargs = all_kwargs[self.rpc_rank]
self.vllm_config = kwargs.get("vllm_config", None) self.vllm_config = kwargs.get("vllm_config")
assert self.vllm_config is not None, ( assert self.vllm_config is not None, (
"vllm_config is required to initialize the worker") "vllm_config is required to initialize the worker")
enable_trace_function_call_for_thread(self.vllm_config) enable_trace_function_call_for_thread(self.vllm_config)