[BugFix]: Batch generation from prompt_embeds fails for long prompts (#21390)
Signed-off-by: KazusatoOko <kazusto.oko@sakana.ai>
Co-authored-by: KazusatoOko <kazusto.oko@sakana.ai>
parent f8c15c4efb
commit fd48d99ffd
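
Context for the fix: the failure shows up when a batch mixes prompt embeddings of very different lengths, so that on some scheduler step a short prompt is already sampling tokens while a long prompt is still in prefill. A rough way to drive that path is sketched below; the `enable_prompt_embeds` flag and the `"prompt_embeds"` prompt key follow vLLM's prompt-embeds documentation from this era, while the model name, prompt lengths, and the exact threshold that triggers the bug are assumptions, not a verified reproducer.

# Hedged sketch: batched generation from prompt embeddings in vLLM.
# Model name, prompt lengths, and the failure threshold are assumptions.
import torch
import transformers

from vllm import LLM

model_name = "meta-llama/Llama-3.2-1B-Instruct"

# Compute prompt embeddings outside vLLM with the same base model.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
hf_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
embed_layer = hf_model.get_input_embeddings()

def to_embeds(text: str) -> torch.Tensor:
    ids = tokenizer(text, return_tensors="pt").input_ids
    with torch.no_grad():
        return embed_layer(ids).squeeze(0)  # (seq_len, hidden_size)

llm = LLM(model=model_name, enable_prompt_embeds=True)

# One short and one deliberately long prompt in the same batch, so the long
# prompt can still be in prefill on a step where the short one samples.
prompts = [
    {"prompt_embeds": to_embeds("Hello, my name is")},
    {"prompt_embeds": to_embeds("background context " * 2000)},
]
outputs = llm.generate(prompts)
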
@@ -1785,24 +1785,32 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         if model_input.inputs_embeds is not None:
             if self.is_driver_worker:
-                sampled = broadcast_tensor_dict(
-                    {"token_ids": output.sampled_token_ids})
+                sampled_token_ids = []
+                valid_outputs = []
+                for sequence_group_output in output.outputs:
+                    if len(sequence_group_output.samples) == 0:
+                        continue
+                    assert len(sequence_group_output.samples) == 1
+                    valid_outputs.append(sequence_group_output)
+                    sampled_token_ids.append(
+                        sequence_group_output.samples[0].output_token)
+                sampled_token_ids = torch.tensor(sampled_token_ids).to(
+                    self.device)
+                sampled_token_ids = broadcast_tensor_dict(
+                    {"sampled_token_ids":
+                     sampled_token_ids})["sampled_token_ids"]
             else:
-                sampled = broadcast_tensor_dict()
-            if sampled["token_ids"] is not None:
-                sampled_token_embeds = self.model.get_input_embeddings(
-                    sampled["token_ids"].squeeze(1))
+                sampled_token_ids = broadcast_tensor_dict(
+                )["sampled_token_ids"]
+            if len(sampled_token_ids) > 0:
+                sampled_token_embeds = \
+                    self.model.get_input_embeddings(sampled_token_ids)
                 if self.is_driver_worker:
                     self.sampler.include_gpu_probs_tensor = \
                         orig_include_gpu_probs

-                    output.sampled_token_embeds = sampled_token_embeds
-
-                    for token_embed, sequence_group_output in zip(
-                            output.sampled_token_embeds, output.outputs):
-                        assert len(sequence_group_output.samples) == 1
-                        sequence_group_output.samples[
-                            0].output_embed = token_embed
+                    for i, sequence_group_output in enumerate(valid_outputs):
+                        sequence_group_output.samples[0].output_embed = \
+                            sampled_token_embeds[i]

         if not self.is_driver_worker:
             return []
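
What the hunk changes: the old code zipped the freshly embedded tokens against every entry of `output.outputs`, but with long prompts some sequence groups have produced no sample on a given step, so the zip pairs embeddings with the wrong groups (or trips the length-1 assert). The new code gathers only the groups that actually sampled, embeds exactly those token ids, and writes each embedding back by index. A minimal standalone sketch of that alignment logic follows; `Sample` and `SequenceGroupOutput` are simplified stand-ins for vLLM's internal types, and the embedding table is a plain `torch.nn.Embedding`, not the model's.

# Standalone sketch of the alignment fix; the classes below are simplified
# stand-ins for vLLM's internal types, not the real implementations.
from dataclasses import dataclass, field
from typing import Optional

import torch


@dataclass
class Sample:
    output_token: int
    output_embed: Optional[torch.Tensor] = None


@dataclass
class SequenceGroupOutput:
    samples: list[Sample] = field(default_factory=list)


def attach_output_embeds(outputs: list[SequenceGroupOutput],
                         embed: torch.nn.Embedding) -> None:
    # Keep only the groups that actually sampled a token this step; a long
    # prompt still in (chunked) prefill contributes an empty samples list.
    valid_outputs = [o for o in outputs if len(o.samples) > 0]
    if len(valid_outputs) == 0:
        return
    token_ids = torch.tensor(
        [o.samples[0].output_token for o in valid_outputs])
    # Embed exactly those token ids ...
    token_embeds = embed(token_ids)
    # ... and write each embedding back to the group it came from, by index,
    # instead of zipping against the full (possibly sparse) output list.
    for i, group in enumerate(valid_outputs):
        group.samples[0].output_embed = token_embeds[i]


# Tiny demo: the middle group is a long prompt mid-prefill with no sample yet.
embed = torch.nn.Embedding(num_embeddings=100, embedding_dim=4)
outputs = [SequenceGroupOutput([Sample(7)]),
           SequenceGroupOutput([]),  # no sample this step
           SequenceGroupOutput([Sample(42)])]
attach_output_embeds(outputs, embed)
assert outputs[1].samples == []
assert outputs[2].samples[0].output_embed is not None
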