diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 4bea37c853031..5a185e7451ade 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1785,24 +1785,32 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
 
         if model_input.inputs_embeds is not None:
             if self.is_driver_worker:
-                sampled = broadcast_tensor_dict(
-                    {"token_ids": output.sampled_token_ids})
+                sampled_token_ids = []
+                valid_outputs = []
+                for sequence_group_output in output.outputs:
+                    if len(sequence_group_output.samples) == 0:
+                        continue
+                    assert len(sequence_group_output.samples) == 1
+                    valid_outputs.append(sequence_group_output)
+                    sampled_token_ids.append(
+                        sequence_group_output.samples[0].output_token)
+                sampled_token_ids = torch.tensor(sampled_token_ids).to(
+                    self.device)
+                sampled_token_ids = broadcast_tensor_dict(
+                    {"sampled_token_ids":
+                     sampled_token_ids})["sampled_token_ids"]
             else:
-                sampled = broadcast_tensor_dict()
-            if sampled["token_ids"] is not None:
-                sampled_token_embeds = self.model.get_input_embeddings(
-                    sampled["token_ids"].squeeze(1))
+                sampled_token_ids = broadcast_tensor_dict(
+                )["sampled_token_ids"]
+            if len(sampled_token_ids) > 0:
+                sampled_token_embeds = \
+                    self.model.get_input_embeddings(sampled_token_ids)
                 if self.is_driver_worker:
                     self.sampler.include_gpu_probs_tensor = \
                         orig_include_gpu_probs
-
-                    output.sampled_token_embeds = sampled_token_embeds
-
-                    for token_embed, sequence_group_output in zip(
-                            output.sampled_token_embeds, output.outputs):
-                        assert len(sequence_group_output.samples) == 1
-                        sequence_group_output.samples[
-                            0].output_embed = token_embed
+                    for i, sequence_group_output in enumerate(valid_outputs):
+                        sequence_group_output.samples[0].output_embed = \
+                            sampled_token_embeds[i]
 
         if not self.is_driver_worker:
             return []
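For context on what the hunk does: the old code zipped `output.sampled_token_embeds` against `output.outputs` and asserted exactly one sample per sequence group, which breaks when a group produces no samples. The new code instead gathers the token id from each non-empty group (tracking the survivors in `valid_outputs`), broadcasts the resulting tensor from the driver to the other ranks via `broadcast_tensor_dict`, has every rank compute the embeddings, and then has the driver write each embedding back to its matching valid group. Below is a minimal, self-contained sketch of that gather/embed/write-back pattern; the `Sample`/`SequenceGroupOutput` stand-ins and the toy embedding table are hypothetical simplifications (no distributed broadcast), not vLLM's actual classes.

```python
import torch

# Hypothetical stand-ins for vLLM's sampler output objects, reduced to the
# fields this code path touches; they are not the real vLLM classes.
class Sample:
    def __init__(self, output_token: int):
        self.output_token = output_token
        self.output_embed = None

class SequenceGroupOutput:
    def __init__(self, samples):
        self.samples = samples

def gather_valid_outputs(outputs):
    """Driver-side gather: skip sequence groups with no samples, keeping the
    surviving groups aligned with their sampled token ids."""
    sampled_token_ids, valid_outputs = [], []
    for sequence_group_output in outputs:
        if len(sequence_group_output.samples) == 0:
            continue
        assert len(sequence_group_output.samples) == 1
        valid_outputs.append(sequence_group_output)
        sampled_token_ids.append(sequence_group_output.samples[0].output_token)
    return torch.tensor(sampled_token_ids), valid_outputs

# Two populated groups and one empty one (the case the old zip mishandled).
outputs = [
    SequenceGroupOutput([Sample(11)]),
    SequenceGroupOutput([]),
    SequenceGroupOutput([Sample(42)]),
]
token_ids, valid = gather_valid_outputs(outputs)

# In vLLM, `token_ids` is what broadcast_tensor_dict would ship to every
# rank; here a toy embedding table stands in for model.get_input_embeddings.
embedding = torch.nn.Embedding(100, 8)
if len(token_ids) > 0:
    sampled_token_embeds = embedding(token_ids)
    # Write each embedding back to its matching non-empty group, so indices
    # stay aligned even though one group was skipped.
    for i, sequence_group_output in enumerate(valid):
        sequence_group_output.samples[0].output_embed = sampled_token_embeds[i]

print(token_ids.tolist())   # [11, 42]
print(outputs[1].samples)   # [] -- the empty group is left untouched
```

Note the design choice the sketch mirrors: filtering and enumeration both run over the same `valid_outputs` list, so the write-back index `i` always refers to the group that contributed `sampled_token_embeds[i]`, without any assumption that every group sampled a token.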