[Misc] Misc code cleanup/simplification (#23304)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-08-21 10:22:55 -07:00 committed by GitHub
parent 10f535c086
commit 603fbbbce0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 51 additions and 57 deletions

View File

@ -91,7 +91,7 @@ class Sampler(nn.Module):
logits = self.apply_bad_words(logits, sampling_metadata)
# Apply logits processors which can impact greedy sampling
for processor in (sampling_metadata.logitsprocs.non_argmax_invariant):
for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
logits = processor.apply(logits)
# Apply penalties (e.g., min_tokens, freq_penalties).

View File

@ -442,10 +442,11 @@ class InputBatch:
# LoRA
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
self.lora_id_to_request_ids[lora_id].discard(req_id)
if len(self.lora_id_to_request_ids[lora_id]) == 0:
self.lora_id_to_request_ids.pop(lora_id)
self.lora_id_to_lora_request.pop(lora_id)
lora_req_ids = self.lora_id_to_request_ids[lora_id]
lora_req_ids.discard(req_id)
if not lora_req_ids:
del self.lora_id_to_request_ids[lora_id]
del self.lora_id_to_lora_request[lora_id]
self.request_lora_mapping[req_index] = 0
self.has_allowed_token_ids.discard(req_id)

View File

@ -358,6 +358,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if num_pooling_reqs == 0:
return model_kwargs
# This does nontrivial work.
pooling_params = self.input_batch.pooling_metadata.pooling_params
assert num_pooling_reqs == num_reqs
@ -465,7 +466,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for req_id in unscheduled_req_ids:
self.input_batch.remove_request(req_id)
req_ids_to_add: list[str] = []
reqs_to_add: list[CachedRequestState] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
@ -480,14 +481,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
generator = None
if pooling_params:
assert (task := pooling_params.task) is not None, (
"You did not set `task` in the API")
task = pooling_params.task
assert task is not None, "You did not set `task` in the API"
model = cast(VllmModelForPooling, self.get_model())
to_update = model.pooler.get_pooling_updates(task)
to_update.apply(pooling_params)
self.requests[req_id] = CachedRequestState(
req_state = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
mm_kwargs=new_req_data.mm_kwargs,
@ -501,6 +502,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
lora_request=new_req_data.lora_request,
)
self.requests[req_id] = req_state
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
image_grid_thw = []
@ -508,29 +511,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
second_per_grid_ts = []
audio_feature_lengths = []
use_audio_in_video = False
for mm_item in self.requests[req_id].mm_kwargs:
for mm_item in req_state.mm_kwargs:
mm_input = mm_item.get_data()
if mm_input.get("image_grid_thw") is not None:
image_grid_thw.append(
mm_input["image_grid_thw"].tolist())
if mm_input.get("video_grid_thw") is not None:
video_grid_thw.append(
mm_input["video_grid_thw"].tolist())
if mm_input.get("second_per_grid_ts") is not None:
second_per_grid_ts.append(
mm_input["second_per_grid_ts"])
if mm_input.get("audio_feature_lengths") is not None:
audio_feature_lengths.append(
mm_input["audio_feature_lengths"])
if (t := mm_input.get("image_grid_thw")) is not None:
image_grid_thw.append(t.tolist())
if (t := mm_input.get("video_grid_thw")) is not None:
video_grid_thw.append(t.tolist())
if (t := mm_input.get("second_per_grid_ts")) is not None:
second_per_grid_ts.append(t)
if (t :=
mm_input.get("audio_feature_lengths")) is not None:
audio_feature_lengths.append(t)
if mm_input.get("use_audio_in_video") is True:
use_audio_in_video = True
hf_config = self.model_config.hf_config
self.requests[req_id].mrope_positions, \
self.requests[req_id].mrope_position_delta = \
req_state.mrope_positions, req_state.mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
self.requests[req_id].prompt_token_ids,
req_state.prompt_token_ids,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
@ -539,7 +538,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
use_audio_in_video=use_audio_in_video,
)
req_ids_to_add.append(req_id)
reqs_to_add.append(req_state)
# Update the states of the running/resumed requests.
is_last_rank = get_pp_group().is_last_rank
@ -587,7 +586,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# The request is not in the persistent batch.
# The request was either preempted and resumed later, or was not
# scheduled in the previous step and needs to be added again.
req_ids_to_add.append(req_id)
reqs_to_add.append(req_state)
continue
# Update the persistent batch.
@ -624,9 +623,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
for req_id in req_ids_to_add:
req_state = self.requests[req_id]
self.input_batch.add_request(req_state)
for request in reqs_to_add:
self.input_batch.add_request(request)
# Condense the batched states if there are gaps left by removed requests
self.input_batch.condense()
@ -639,38 +637,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
) -> BatchedTensorInputs:
if self.is_multimodal_raw_input_supported: # noqa: SIM102
if scheduler_output:
mm_kwargs = list[MultiModalKwargsItem]()
for req in scheduler_output.scheduled_new_reqs:
req_mm_kwargs = req.mm_kwargs
if not isinstance(req_mm_kwargs, list):
req_mm_kwargs = list(req_mm_kwargs)
mm_kwargs.extend(req_mm_kwargs)
if not self.is_multimodal_raw_input_supported or not scheduler_output: # noqa: SIM102
return {}
# Input all modalities at once
mm_kwargs_combined: BatchedTensorInputs = {}
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
pin_memory=self.pin_memory,
):
mm_kwargs_combined.update(mm_kwargs_group)
mm_kwargs = list[MultiModalKwargsItem]()
for req in scheduler_output.scheduled_new_reqs:
mm_kwargs.extend(req.mm_kwargs)
return mm_kwargs_combined
# Input all modalities at once
mm_kwargs_combined: BatchedTensorInputs = {}
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
pin_memory=self.pin_memory,
):
mm_kwargs_combined.update(mm_kwargs_group)
return {}
return mm_kwargs_combined
def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
if self.is_multimodal_raw_input_supported:
mm_budget = self.mm_budget
assert mm_budget is not None
if not self.is_multimodal_raw_input_supported:
return {}
mm_budget = self.mm_budget
assert mm_budget is not None
dummy_modality = mm_budget.get_modality_with_max_tokens()
return self._get_mm_dummy_batch(dummy_modality, num_seqs)
return {}
dummy_modality = mm_budget.get_modality_with_max_tokens()
return self._get_mm_dummy_batch(dummy_modality, num_seqs)
def _get_cumsum_and_arange(
self,
@ -1612,6 +1604,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
batch_descriptor=batch_descriptor,
), self.maybe_get_kv_connector_output(
scheduler_output) as kv_connector_output:
model_output = self.model(
input_ids=input_ids,
positions=positions,

View File

@ -544,7 +544,7 @@ class WorkerWrapperBase:
Arguments are passed to the worker class constructor.
"""
kwargs = all_kwargs[self.rpc_rank]
self.vllm_config = kwargs.get("vllm_config", None)
self.vllm_config = kwargs.get("vllm_config")
assert self.vllm_config is not None, (
"vllm_config is required to initialize the worker")
enable_trace_function_call_for_thread(self.vllm_config)