refactor: Turn GPUModelRunner.inputs_embeds to a CpuGpuBuffer (#24345)

Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Andrew Sansom 2025-09-06 01:01:23 -05:00 committed by GitHub
parent 6d6c6b05d3
commit 305a1cc0d2
2 changed files with 37 additions and 14 deletions

vllm/v1/utils.py

@@ -19,6 +19,8 @@ from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
                         kill_process_tree)
 
 if TYPE_CHECKING:
+    import numpy as np
+
     from vllm.v1.engine.coordinator import DPCoordinator
     from vllm.v1.engine.utils import (CoreEngineActorManager,
                                       CoreEngineProcManager)
@@ -97,20 +99,31 @@ class ConstantList(Generic[T], Sequence):
 
 
 class CpuGpuBuffer:
+    """Buffer to easily copy tensors between CPU and GPU."""
 
     def __init__(
         self,
-        *args,
+        *size: Union[int, torch.SymInt],
         dtype: torch.dtype,
         device: torch.device,
         pin_memory: bool,
-    ):
-        self.cpu = torch.zeros(*args,
+        with_numpy: bool = True,
+    ) -> None:
+        self.cpu = torch.zeros(*size,
                                dtype=dtype,
                                device="cpu",
                                pin_memory=pin_memory)
-        self.np = self.cpu.numpy()
         self.gpu = self.cpu.to(device)
+        self.np: np.ndarray
+        # To keep type hints simple (avoiding generics and subclasses), we
+        # only conditionally create the numpy array attribute. This can cause
+        # AttributeError if `self.np` is accessed when `with_numpy=False`.
+        if with_numpy:
+            if dtype == torch.bfloat16:
+                raise ValueError(
+                    "Bfloat16 torch tensors cannot be directly cast to a "
+                    "numpy array, so call CpuGpuBuffer with with_numpy=False")
+            self.np = self.cpu.numpy()
 
     def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
         if n is None:

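The pattern this file introduces can be sketched standalone. The toy class below mirrors the diff above; the `copy_to_gpu` body and its `non_blocking` copy are assumptions based on the method name shown in the hunk, not lines from this commit:

import torch
from typing import Optional, Union


class ToyCpuGpuBuffer:
    """Minimal sketch of the CpuGpuBuffer pattern above (not vLLM's code)."""

    def __init__(self, *size: Union[int, torch.SymInt], dtype: torch.dtype,
                 device: torch.device, pin_memory: bool,
                 with_numpy: bool = True) -> None:
        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu",
                               pin_memory=pin_memory)
        self.gpu = self.cpu.to(device)
        if with_numpy:
            if dtype == torch.bfloat16:
                raise ValueError("numpy has no bfloat16 dtype")
            # Zero-copy view: writes to self.np are visible in self.cpu.
            self.np = self.cpu.numpy()

    def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
        # Stage data in the pinned CPU tensor, then mirror it to the device.
        if n is None:
            return self.gpu.copy_(self.cpu, non_blocking=True)
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]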
vllm/v1/worker/gpu_model_runner.py

@@ -303,10 +303,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.query_start_loc = self._make_buffer(self.max_num_reqs + 1,
                                                  dtype=torch.int32)
         self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
-        self.inputs_embeds = torch.zeros(
-            (self.max_num_tokens, self.hidden_size),
-            dtype=self.dtype,
-            device=self.device)
+        # Because inputs_embeds may be bfloat16 and we don't need a numpy
+        # version of this tensor, avoid a RuntimeError by not creating a
+        # numpy buffer.
+        self.inputs_embeds = self._make_buffer(self.max_num_tokens,
+                                               self.hidden_size,
+                                               dtype=self.dtype,
+                                               numpy=False)
 
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
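For context on the comment above: numpy has no bfloat16 dtype, so viewing a bfloat16 torch tensor as a numpy array raises immediately. A quick self-contained check (recent PyTorch builds raise TypeError, so both exception classes are caught here to be safe):

import torch

t = torch.zeros(4, dtype=torch.bfloat16)
try:
    t.numpy()
except (TypeError, RuntimeError) as exc:
    # PyTorch reports something like "Got unsupported ScalarType BFloat16".
    print(exc)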
@@ -374,11 +377,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 device="cpu",
                 pin_memory=self.pin_memory)
 
-    def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
-        return CpuGpuBuffer(*args,
+    def _make_buffer(self,
+                     *size: Union[int, torch.SymInt],
+                     dtype: torch.dtype,
+                     numpy: bool = True) -> CpuGpuBuffer:
+        # Bfloat16 torch tensors cannot be directly cast to a numpy array, so
+        # if a bfloat16 buffer is needed without a corresponding numpy array,
+        # don't bother instantiating the numpy array.
+        return CpuGpuBuffer(*size,
                             dtype=dtype,
                             device=self.device,
-                            pin_memory=self.pin_memory)
+                            pin_memory=self.pin_memory,
+                            with_numpy=numpy)
 
     def _init_model_kwargs(self, num_tokens: int):
         model_kwargs = dict[str, Any]()
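The new `_make_buffer` signature captures the shape as positional varargs while forcing `dtype` and `numpy` to be keyword-only, so a flag can never be swallowed as an extra dimension. A self-contained sketch of that calling convention (the free function here is illustrative, not vLLM's method):

from typing import Optional, Union

import numpy as np
import torch


def make_buffer(
        *size: Union[int, torch.SymInt],
        dtype: torch.dtype,
        numpy: bool = True) -> tuple[torch.Tensor, Optional[np.ndarray]]:
    # *size collects the dimensions; dtype/numpy must be passed by keyword.
    cpu = torch.zeros(*size, dtype=dtype)
    return cpu, (cpu.numpy() if numpy else None)


cpu, np_view = make_buffer(16, 8, dtype=torch.float32)           # has numpy view
bf16_cpu, _ = make_buffer(16, 8, dtype=torch.bfloat16, numpy=False)  # skips it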
@@ -1645,11 +1655,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             )
 
             # TODO(woosuk): Avoid the copy. Optimize.
-            self.inputs_embeds[:num_scheduled_tokens].copy_(
+            self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(
                 inputs_embeds_scheduled)
 
             input_ids = None
-            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
             model_kwargs = {
                 **self._init_model_kwargs(num_scheduled_tokens),
                 **self._extract_mm_kwargs(scheduler_output),
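Because `self.inputs_embeds` is now a buffer object rather than a bare tensor, call sites slice its `.gpu` attribute explicitly. A tiny sketch of the changed access pattern (class name and shapes made up for illustration):

import torch


class _Buf:
    # Stand-in for the buffer object: indexing goes through the .gpu tensor.
    def __init__(self) -> None:
        self.gpu = torch.zeros(8, 16)  # lives on the device in vLLM


buf = _Buf()
buf.gpu[:4].copy_(torch.ones(4, 16))  # was: self.inputs_embeds[:4].copy_(...)
view = buf.gpu[:4]                    # was: self.inputs_embeds[:4]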
@@ -2484,7 +2494,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                num_scheduled_tokens, remove_lora):
             if self.supports_mm_inputs:
                 input_ids = None
-                inputs_embeds = self.inputs_embeds[:num_tokens]
+                inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
                 model_kwargs = {
                     **self._init_model_kwargs(num_tokens),
                     **self._dummy_mm_kwargs(num_reqs),