diff --git a/vllm/distributed/device_communicators/shm_object_storage.py b/vllm/distributed/device_communicators/shm_object_storage.py
index 2ec33afb87839..4af2caa16b0d6 100644
--- a/vllm/distributed/device_communicators/shm_object_storage.py
+++ b/vllm/distributed/device_communicators/shm_object_storage.py
@@ -342,8 +342,8 @@ class MsgpackSerde(ObjectSerde):
         from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 
         self.encoder = MsgpackEncoder()
-        self.tensor_decoder = MsgpackDecoder(torch.Tensor)
-        self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem)
+        self.tensor_decoder = MsgpackDecoder(torch.Tensor, share_mem=False)
+        self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem, share_mem=False)
         self._mm_kwargs_item_cls = MultiModalKwargsItem
 
     def serialize(self, value: Any) -> tuple[bytes | list[bytes], int, bytes, int]:
@@ -368,7 +368,7 @@ class MsgpackSerde(ObjectSerde):
         # pickle.loads do not read past the end of a pickled object
         # within a large buffer, so we can skip storing the metadata size
         type_name, nbytes, len_arr = pickle.loads(data_view)
-        serialized_data = bytearray(data_view[-nbytes:])
+        serialized_data = data_view[-nbytes:]
 
         if type_name == torch.Tensor.__name__:
             obj = []
diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py
index 102357ca7c642..cf0b1a41b50f8 100644
--- a/vllm/v1/serial_utils.py
+++ b/vllm/v1/serial_utils.py
@@ -31,6 +31,7 @@ from vllm.multimodal.inputs import (
     MultiModalSharedField,
     NestedTensors,
 )
+from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.v1.engine import UtilityResult
 from vllm.v1.utils import tensor_data
 
@@ -282,7 +283,9 @@ class MsgpackDecoder:
     not thread-safe when encoding tensors / numpy arrays.
     """
 
-    def __init__(self, t: Any | None = None):
+    def __init__(self, t: Any | None = None, share_mem: bool = True):
+        self.share_mem = share_mem
+        self.pin_tensors = is_pin_memory_available()
         args = () if t is None else (t,)
         self.decoder = msgpack.Decoder(
             *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
@@ -347,21 +350,30 @@ class MsgpackDecoder:
         # zero-copy decode. We assume the ndarray will not be kept around,
         # as it now locks the whole received message buffer in memory.
         buffer = self.aux_buffers[data] if isinstance(data, int) else data
-        return np.frombuffer(buffer, dtype=dtype).reshape(shape)
+        arr = np.frombuffer(buffer, dtype=dtype)
+        if not self.share_mem:
+            arr = arr.copy()
+        return arr.reshape(shape)
 
     def _decode_tensor(self, arr: Any) -> torch.Tensor:
         dtype, shape, data = arr
-        # Copy from inline representation, to decouple the memory storage
-        # of the message from the original buffer. And also make Torch
-        # not complain about a readonly memoryview.
-        buffer = self.aux_buffers[data] if isinstance(data, int) else bytearray(data)
+        is_aux = isinstance(data, int)
+        buffer = self.aux_buffers[data] if is_aux else data
+        buffer = buffer if isinstance(buffer, memoryview) else memoryview(buffer)
         torch_dtype = getattr(torch, dtype)
         assert isinstance(torch_dtype, torch.dtype)
-        if not buffer:  # torch.frombuffer doesn't like empty buffers
+        if not buffer.nbytes:  # torch.frombuffer doesn't like empty buffers
             assert 0 in shape
             return torch.empty(shape, dtype=torch_dtype)
         # Create uint8 array
         arr = torch.frombuffer(buffer, dtype=torch.uint8)
+        # Clone ensures tensor is backed by pytorch-owned memory for safe
+        # future async CPU->GPU transfer.
+        # Pin larger tensors for more efficient CPU->GPU transfer.
+        if not is_aux:
+            arr = arr.clone()
+        elif not self.share_mem:
+            arr = arr.pin_memory() if self.pin_tensors else arr.clone()
         # Convert back to proper shape & type
         return arr.view(torch_dtype).view(shape)
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index c9c64137ca04b..d0f7f3a501f59 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2590,28 +2590,28 @@ class GPUModelRunner(
                 )
             )
 
-            dp_rank = self.parallel_config.data_parallel_rank
-            if ubatch_slices:
-                assert num_tokens_across_dp is not None
-                num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
-                self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
-            elif num_tokens_across_dp is not None:
-                num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
-            else:
-                num_input_tokens = self._get_num_input_tokens(
-                    scheduler_output.total_num_scheduled_tokens
-                )
+            dp_rank = self.parallel_config.data_parallel_rank
+            if ubatch_slices:
+                assert num_tokens_across_dp is not None
+                num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
+                self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
+            elif num_tokens_across_dp is not None:
+                num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
+            else:
+                num_input_tokens = self._get_num_input_tokens(
+                    scheduler_output.total_num_scheduled_tokens
+                )
 
-            (
-                input_ids,
-                inputs_embeds,
-                positions,
-                intermediate_tensors,
-                model_kwargs,
-                ec_connector_output,
-            ) = self._preprocess(
-                scheduler_output, num_input_tokens, intermediate_tensors
-            )
+            (
+                input_ids,
+                inputs_embeds,
+                positions,
+                intermediate_tensors,
+                model_kwargs,
+                ec_connector_output,
+            ) = self._preprocess(
+                scheduler_output, num_input_tokens, intermediate_tensors
+            )
 
             uniform_decode = (
                 max_num_scheduled_tokens == self.uniform_decode_query_len
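
The key behavioral change is in _decode_tensor: with share_mem=False, a decoded tensor no longer aliases the shared-memory ring buffer but is backed by PyTorch-owned (and, where supported, pinned) storage, which is what makes a later non-blocking CPU->GPU copy safe once the shm slot is reused. The standalone sketch below is not vLLM code; decode_tensor, the dummy buffer, and the torch.cuda.is_available() stand-in for is_pin_memory_available() are illustrative assumptions. It only shows why the copy/pin matters when the source buffer is overwritten.

import torch

def decode_tensor(
    buffer: memoryview, dtype: torch.dtype, shape: tuple, share_mem: bool
) -> torch.Tensor:
    # torch.frombuffer rejects empty buffers, mirroring the guard in the patch.
    if not buffer.nbytes:
        return torch.empty(shape, dtype=dtype)
    arr = torch.frombuffer(buffer, dtype=torch.uint8)
    if not share_mem:
        # Pin when an accelerator is present so a later
        # .to("cuda", non_blocking=True) can overlap with compute;
        # otherwise fall back to a plain clone. Both give owned memory.
        arr = arr.pin_memory() if torch.cuda.is_available() else arr.clone()
    return arr.view(dtype).view(shape)

src = torch.arange(1, 17, dtype=torch.float32)
buf = memoryview(bytearray(src.numpy().tobytes()))  # stands in for a shm ring-buffer slot

aliased = decode_tensor(buf, torch.float32, (4, 4), share_mem=True)
owned = decode_tensor(buf, torch.float32, (4, 4), share_mem=False)

buf[:4] = b"\x00\x00\x00\x00"  # simulate the slot being reused/overwritten
print(aliased[0, 0].item())  # 0.0 -- still aliases the (now overwritten) buffer
print(owned[0, 0].item())    # 1.0 -- unaffected, safe to hand to the GPU later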