diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index a24307b79d95e..7518cd8fc897f 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -278,6 +278,11 @@ class InputPreprocessor: raise ValueError( "prompt_embeds must be of shape (seq_len, hidden_size).") + # Tensors must be on CPU for serialization between processes + # in the MsgpackEncoder. Moving them to CPU here ensures that there is no + # hidden device transfer in the critical path of generation. + prompt_embeds = prompt_embeds.cpu() + return embeds_inputs(prompt_embeds=prompt_embeds, cache_salt=parsed_content.get("cache_salt")) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index c812a2ec6427a..876838084b9aa 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -208,7 +208,7 @@ class MsgpackEncoder: ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: assert self.aux_buffers is not None # view the tensor as a contiguous 1D array of bytes - arr = obj.flatten().contiguous().cpu().view(torch.uint8).numpy() + arr = obj.flatten().contiguous().view(torch.uint8).numpy() if obj.nbytes < self.size_threshold: # Smaller tensors are encoded inline, just like ndarrays. data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)