diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 4206a24465e2..cae1a25519b3 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -207,10 +207,7 @@ class StatelessProcessGroup: def barrier(self): """A barrier to synchronize all ranks.""" for i in range(self.world_size): - if i == self.rank: - self.broadcast_obj(None, src=self.rank) - else: - self.broadcast_obj(None, src=i) + self.broadcast_obj(None, src=i) @staticmethod def create( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 448119761259..094602a8b732 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -269,29 +269,26 @@ class Scheduler(SchedulerInterface): request = self.waiting[0] - # Waiting request skipping logic - is_skipped = False # Skip request if the structured output request is still waiting - # for FSM. - if (not is_skipped - and request.status == RequestStatus.WAITING_FOR_FSM): + # for FSM compilation. + if request.status == RequestStatus.WAITING_FOR_FSM: structured_output_req = request.structured_output_request - is_skipped = (not structured_output_req - or not structured_output_req.grammar) - if not is_skipped: + if structured_output_req and structured_output_req.grammar: request.status = RequestStatus.WAITING + else: + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue - # Skip request if max_loras can't be honored. - if (not is_skipped and self.lora_config - and request.lora_request): - req_lora_id = request.lora_request.lora_int_id - is_skipped = (len(scheduled_loras) - == self.lora_config.max_loras - and (req_lora_id not in scheduled_loras)) - - if is_skipped: - skipped_waiting_requests.appendleft(request) + # Check that adding the request still respects the max_loras + # constraint. + if self.lora_config and request.lora_request and ( + len(scheduled_loras) == self.lora_config.max_loras + and request.lora_request.lora_int_id + not in scheduled_loras): + # Scheduling would exceed max_loras, skip. self.waiting.popleft() + skipped_waiting_requests.appendleft(request) continue # Get already-cached tokens. @@ -602,8 +599,9 @@ class Scheduler(SchedulerInterface): # OPTIMIZATION: Avoid list(set) if the set is empty. if cached_encoder_input_ids: for input_id in list(cached_encoder_input_ids): - start_pos = request.mm_positions[input_id]["offset"] - num_tokens = request.mm_positions[input_id]["length"] + mm_positions = request.mm_positions[input_id] + start_pos = mm_positions["offset"] + num_tokens = mm_positions["length"] if start_pos + num_tokens <= request.num_computed_tokens: # The encoder output is already processed and stored # in the decoder's KV cache. @@ -616,25 +614,24 @@ class Scheduler(SchedulerInterface): stopped = False new_logprobs = None - new_token_ids: list[int] = [] + new_token_ids = generated_token_ids # Append generated tokens and check for stop. Note that if # a request is still being prefilled, we expect the model runner # to return empty token ids for the request. - for output_token_id in generated_token_ids: + for num_new, output_token_id in enumerate(new_token_ids, 1): request.append_output_token_ids(output_token_id) - new_token_ids.append(output_token_id) # Check for stop and update request state. # This must be called before we make the EngineCoreOutput. stopped = check_stop(request, self.max_model_len) if stopped: self._free_request(request) + del new_token_ids[num_new:] # Trim new tokens if needed. break # Extract sample logprobs if needed. - if (request.sampling_params.logprobs is not None - and logprobs is not None): + if request.sampling_params.logprobs is not None and logprobs: # NOTE: once we support N tokens per step (spec decode), # the outer lists can be of length > 1. new_logprobs = logprobs.slice(req_index, req_index + 1) @@ -644,9 +641,7 @@ class Scheduler(SchedulerInterface): # should not be None if use_structured_output, we have # check above, so safe to ignore type warning request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] - request.request_id, - new_token_ids, - ) + req_id, new_token_ids) # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) @@ -665,7 +660,7 @@ class Scheduler(SchedulerInterface): # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors - self.scheduled_req_ids.remove(request.request_id) + self.scheduled_req_ids.remove(req_id) if not stopped: new_running.append(request) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index c41ee6704be0..8858a564d2c2 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -416,9 +416,9 @@ class SyncMPClient(MPClient): def process_outputs_socket(): shutdown_socket = ctx.socket(zmq.PAIR) - shutdown_socket.bind(shutdown_path) out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL) try: + shutdown_socket.bind(shutdown_path) poller = zmq.Poller() poller.register(shutdown_socket) poller.register(out_socket) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 1e67bed26118..70f072d3c939 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -328,7 +328,7 @@ class OutputProcessor: # 2) Detokenize the token ids into text and perform stop checks. stop_string = req_state.detokenizer.update( new_token_ids, finish_reason == FinishReason.STOP) - if stop_string and finish_reason != FinishReason.STOP: + if stop_string: finish_reason = FinishReason.STOP stop_reason = stop_string diff --git a/vllm/v1/request.py b/vllm/v1/request.py index efb5a54d1207..48e5132678c1 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -93,9 +93,11 @@ class Request: token_ids: Union[int, list[int]], ) -> None: if isinstance(token_ids, int): - token_ids = [token_ids] - self._output_token_ids.extend(token_ids) - self._all_token_ids.extend(token_ids) + self._output_token_ids.append(token_ids) + self._all_token_ids.append(token_ids) + else: + self._output_token_ids.extend(token_ids) + self._all_token_ids.extend(token_ids) @property def num_tokens(self) -> int: