[Misc][V1] Misc code streamlining (#15723)

Signed-off-by: Nick Hill <nhill@redhat.com>
Author: Nick Hill <nhill@redhat.com>, 2025-03-28 20:59:47 -07:00 (committed via GitHub)
parent 762b424a52
commit 6d531ad7b8
5 changed files with 32 additions and 38 deletions

View File

@@ -207,9 +207,6 @@ class StatelessProcessGroup:
     def barrier(self):
         """A barrier to synchronize all ranks."""
         for i in range(self.world_size):
-            if i == self.rank:
-                self.broadcast_obj(None, src=self.rank)
-            else:
-                self.broadcast_obj(None, src=i)
+            self.broadcast_obj(None, src=i)

     @staticmethod
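
The removed branch was redundant: broadcast_obj already acts as a send when the caller is the source rank and as a receive otherwise, so every rank can issue the same call. A minimal sketch of that symmetric helper, assuming a shared key-value store as transport (all names here are hypothetical, not vLLM's API):

    # Minimal sketch of a symmetric broadcast_obj and the barrier built
    # on it. A shared key-value store stands in for the real transport;
    # all names are hypothetical.
    import pickle

    class TinyGroup:

        def __init__(self, rank: int, world_size: int, store):
            self.rank = rank
            self.world_size = world_size
            self.store = store  # assumed API: set(key, bytes), get(key) -> bytes
            self._seq = 0

        def broadcast_obj(self, obj, src: int):
            """Send obj if we are src, else block until src publishes it."""
            key = f"bcast/{src}/{self._seq}"
            self._seq += 1
            if self.rank == src:
                self.store.set(key, pickle.dumps(obj))
                return obj
            return pickle.loads(self.store.get(key))  # assumed to block

        def barrier(self):
            # Each rank takes one turn as source, so no rank can leave
            # the barrier before every other rank has entered it.
            for i in range(self.world_size):
                self.broadcast_obj(None, src=i)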

View File

@@ -269,29 +269,26 @@ class Scheduler(SchedulerInterface):
             request = self.waiting[0]

-            # Waiting request skipping logic
-            is_skipped = False
-
             # Skip request if the structured output request is still waiting
-            # for FSM.
-            if (not is_skipped
-                    and request.status == RequestStatus.WAITING_FOR_FSM):
+            # for FSM compilation.
+            if request.status == RequestStatus.WAITING_FOR_FSM:
                 structured_output_req = request.structured_output_request
-                is_skipped = (not structured_output_req
-                              or not structured_output_req.grammar)
-                if not is_skipped:
+                if structured_output_req and structured_output_req.grammar:
                     request.status = RequestStatus.WAITING
-
-            # Skip request if max_loras can't be honored.
-            if (not is_skipped and self.lora_config
-                    and request.lora_request):
-                req_lora_id = request.lora_request.lora_int_id
-                is_skipped = (len(scheduled_loras)
-                              == self.lora_config.max_loras
-                              and (req_lora_id not in scheduled_loras))
-
-            if is_skipped:
-                skipped_waiting_requests.appendleft(request)
-                self.waiting.popleft()
-                continue
+                else:
+                    self.waiting.popleft()
+                    skipped_waiting_requests.appendleft(request)
+                    continue
+
+            # Check that adding the request still respects the max_loras
+            # constraint.
+            if self.lora_config and request.lora_request and (
+                    len(scheduled_loras) == self.lora_config.max_loras
+                    and request.lora_request.lora_int_id
+                    not in scheduled_loras):
+                # Scheduling would exceed max_loras, skip.
+                self.waiting.popleft()
+                skipped_waiting_requests.appendleft(request)
+                continue

             # Get already-cached tokens.
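
Both skip paths now share one shape: pop the request off the head of the waiting deque, stash it, and continue; stashed requests later return to the head of the queue. A minimal standalone sketch of that pattern, with stand-in types rather than the real Request objects:

    # Sketch of the skip-and-requeue pattern used above, with stand-in
    # string "requests" and a placeholder skip predicate.
    from collections import deque

    waiting: deque[str] = deque(["a", "b", "c"])
    skipped_waiting: deque[str] = deque()
    scheduled: list[str] = []

    def should_skip(req: str) -> bool:
        return req == "b"  # stand-in predicate

    while waiting:
        req = waiting[0]
        if should_skip(req):
            waiting.popleft()
            skipped_waiting.appendleft(req)
            continue
        scheduled.append(waiting.popleft())

    # Skipped requests go back on the head of the queue in their
    # original order (appendleft + extendleft double-reverses).
    waiting.extendleft(skipped_waiting)
    print(scheduled, list(waiting))  # ['a', 'c'] ['b']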
@@ -602,8 +599,9 @@ class Scheduler(SchedulerInterface):
             # OPTIMIZATION: Avoid list(set) if the set is empty.
             if cached_encoder_input_ids:
                 for input_id in list(cached_encoder_input_ids):
-                    start_pos = request.mm_positions[input_id]["offset"]
-                    num_tokens = request.mm_positions[input_id]["length"]
+                    mm_positions = request.mm_positions[input_id]
+                    start_pos = mm_positions["offset"]
+                    num_tokens = mm_positions["length"]
                     if start_pos + num_tokens <= request.num_computed_tokens:
                         # The encoder output is already processed and stored
                         # in the decoder's KV cache.
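
This hoists the repeated request.mm_positions[input_id] subscript into a local, so the attribute access and indexing happen once instead of twice per iteration; e.g.:

    # Common-subexpression hoist: one lookup instead of two. The dict
    # below is a stand-in for request.mm_positions[input_id].
    mm_positions = {"offset": 5, "length": 3}
    start_pos = mm_positions["offset"]
    num_tokens = mm_positions["length"]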
@@ -616,25 +614,24 @@ class Scheduler(SchedulerInterface):
             stopped = False
             new_logprobs = None
-            new_token_ids: list[int] = []
+            new_token_ids = generated_token_ids

             # Append generated tokens and check for stop. Note that if
             # a request is still being prefilled, we expect the model runner
             # to return empty token ids for the request.
-            for output_token_id in generated_token_ids:
+            for num_new, output_token_id in enumerate(new_token_ids, 1):
                 request.append_output_token_ids(output_token_id)
-                new_token_ids.append(output_token_id)

                 # Check for stop and update request state.
                 # This must be called before we make the EngineCoreOutput.
                 stopped = check_stop(request, self.max_model_len)
                 if stopped:
                     self._free_request(request)
+                    del new_token_ids[num_new:]  # Trim new tokens if needed.
                     break

             # Extract sample logprobs if needed.
-            if (request.sampling_params.logprobs is not None
-                    and logprobs is not None):
+            if request.sampling_params.logprobs is not None and logprobs:
                 # NOTE: once we support N tokens per step (spec decode),
                 # the outer lists can be of length > 1.
                 new_logprobs = logprobs.slice(req_index, req_index + 1)
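
Reusing generated_token_ids as new_token_ids avoids building a second list per request; when a stop fires mid-step, enumerate's running count (starting at 1) gives the cut point for trimming in place. A small self-contained sketch of that trim pattern, with a stand-in for check_stop():

    # Sketch of the enumerate-and-trim pattern: keep only the tokens
    # accepted up to and including the one that triggered the stop.
    def take_until_stop(token_ids: list[int], stop_token: int) -> list[int]:
        for num_new, tok in enumerate(token_ids, 1):
            if tok == stop_token:  # stand-in for check_stop()
                del token_ids[num_new:]  # trim everything after the stop
                break
        return token_ids

    print(take_until_stop([5, 7, 2, 9, 4], stop_token=2))  # [5, 7, 2]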
@@ -644,9 +641,7 @@ class Scheduler(SchedulerInterface):
                 # should not be None if use_structured_output, we have
                 # check above, so safe to ignore type warning
                 request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
-                    request.request_id,
-                    new_token_ids,
-                )
+                    req_id, new_token_ids)

             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
@@ -665,7 +660,7 @@ class Scheduler(SchedulerInterface):
                 # Invariant: EngineCore returns no partial prefill outputs.
                 assert not prompt_logprobs_tensors

-            self.scheduled_req_ids.remove(request.request_id)
+            self.scheduled_req_ids.remove(req_id)
             if not stopped:
                 new_running.append(request)

View File

@@ -416,9 +416,9 @@ class SyncMPClient(MPClient):
         def process_outputs_socket():
             shutdown_socket = ctx.socket(zmq.PAIR)
-            shutdown_socket.bind(shutdown_path)
             out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL)
             try:
+                shutdown_socket.bind(shutdown_path)
                 poller = zmq.Poller()
                 poller.register(shutdown_socket)
                 poller.register(out_socket)
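
Moving the bind inside the try means a failed bind still reaches the cleanup path that closes both already-created sockets. A minimal sketch of the same pattern in pyzmq (endpoint paths are placeholders, not vLLM's):

    # Once both sockets exist, any fallible call (like bind) goes inside
    # the try so the finally block always releases them. Placeholder
    # ipc endpoints; not the real shutdown/output paths.
    import zmq

    ctx = zmq.Context()
    a = ctx.socket(zmq.PAIR)
    b = ctx.socket(zmq.PULL)
    try:
        a.bind("ipc:///tmp/example-shutdown")  # may raise ZMQError
        b.bind("ipc:///tmp/example-output")
        # ... poll both sockets here ...
    finally:
        a.close(linger=0)
        b.close(linger=0)
        ctx.term()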

View File

@@ -328,7 +328,7 @@ class OutputProcessor:
             # 2) Detokenize the token ids into text and perform stop checks.
             stop_string = req_state.detokenizer.update(
                 new_token_ids, finish_reason == FinishReason.STOP)
-            if stop_string and finish_reason != FinishReason.STOP:
+            if stop_string:
                 finish_reason = FinishReason.STOP
                 stop_reason = stop_string
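
The dropped finish_reason != FinishReason.STOP guard appears redundant: when update() is told the request already stop-terminated (its second argument), it skips stop-string matching and returns None, so a truthy stop_string implies finish_reason was not already STOP. A hedged sketch of that assumed contract (a stand-in function, not vLLM's detokenizer):

    # Stand-in for the contract assumed above: no stop string is ever
    # reported for a request that already terminated via a stop token.
    from typing import Optional

    def update(output_text: str, stop_strings: list[str],
               stop_terminated: bool) -> Optional[str]:
        if stop_terminated:
            return None  # already finished; skip stop-string matching
        for s in stop_strings:
            if s in output_text:
                return s
        return None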

View File

@@ -93,7 +93,9 @@ class Request:
         token_ids: Union[int, list[int]],
     ) -> None:
         if isinstance(token_ids, int):
-            token_ids = [token_ids]
-        self._output_token_ids.extend(token_ids)
-        self._all_token_ids.extend(token_ids)
+            self._output_token_ids.append(token_ids)
+            self._all_token_ids.append(token_ids)
+        else:
+            self._output_token_ids.extend(token_ids)
+            self._all_token_ids.extend(token_ids)
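
Appending the scalar directly avoids allocating a throwaway one-element list on the per-token hot path. A minimal standalone sketch of the same shape, with hypothetical names (the real method updates the request's two parallel token lists):

    # Scalar-vs-list fast path; TokenBuffer and its fields are
    # hypothetical stand-ins for the real Request state.
    from typing import Union

    class TokenBuffer:

        def __init__(self) -> None:
            self.output_token_ids: list[int] = []
            self.all_token_ids: list[int] = []

        def append_output_token_ids(
                self, token_ids: Union[int, list[int]]) -> None:
            if isinstance(token_ids, int):
                # Fast path: no temporary [token_ids] list is built.
                self.output_token_ids.append(token_ids)
                self.all_token_ids.append(token_ids)
            else:
                self.output_token_ids.extend(token_ids)
                self.all_token_ids.extend(token_ids)

    buf = TokenBuffer()
    buf.append_output_token_ids(7)
    buf.append_output_token_ids([8, 9])
    print(buf.all_token_ids)  # [7, 8, 9]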