mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-09 02:51:50 +08:00
refactor: Turn GPUModelRunner.inputs_embeds to a CpuGpuBuffer (#24345)
Signed-off-by: Andrew Sansom <andrew@protopia.ai>
This commit is contained in:
parent
6d6c6b05d3
commit
305a1cc0d2
@ -19,6 +19,8 @@ from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
|
||||
kill_process_tree)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
|
||||
from vllm.v1.engine.coordinator import DPCoordinator
|
||||
from vllm.v1.engine.utils import (CoreEngineActorManager,
|
||||
CoreEngineProcManager)
|
||||
@ -97,20 +99,31 @@ class ConstantList(Generic[T], Sequence):
|
||||
|
||||
|
||||
class CpuGpuBuffer:
|
||||
"""Buffer to easily copy tensors between CPU and GPU."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
*size: Union[int, torch.SymInt],
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
):
|
||||
self.cpu = torch.zeros(*args,
|
||||
with_numpy: bool = True,
|
||||
) -> None:
|
||||
self.cpu = torch.zeros(*size,
|
||||
dtype=dtype,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.np = self.cpu.numpy()
|
||||
self.gpu = self.cpu.to(device)
|
||||
self.np: np.ndarray
|
||||
# To keep type hints simple (avoiding generics and subclasses), we
|
||||
# only conditionally create the numpy array attribute. This can cause
|
||||
# AttributeError if `self.np` is accessed when `with_numpy=False`.
|
||||
if with_numpy:
|
||||
if dtype == torch.bfloat16:
|
||||
raise ValueError(
|
||||
"Bfloat16 torch tensors cannot be directly cast to a "
|
||||
"numpy array, so call CpuGpuBuffer with with_numpy=False")
|
||||
self.np = self.cpu.numpy()
|
||||
|
||||
def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
|
||||
if n is None:
|
||||
|
||||
@ -303,10 +303,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.query_start_loc = self._make_buffer(self.max_num_reqs + 1,
|
||||
dtype=torch.int32)
|
||||
self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
|
||||
self.inputs_embeds = torch.zeros(
|
||||
(self.max_num_tokens, self.hidden_size),
|
||||
dtype=self.dtype,
|
||||
device=self.device)
|
||||
# Because inputs_embeds may be bfloat16 and we don't need a numpy
|
||||
# version of this tensor, avoid a RuntimeError by not creating a
|
||||
# numpy buffer.
|
||||
self.inputs_embeds = self._make_buffer(self.max_num_tokens,
|
||||
self.hidden_size,
|
||||
dtype=self.dtype,
|
||||
numpy=False)
|
||||
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
if self.uses_mrope:
|
||||
@ -374,11 +377,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
|
||||
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(*args,
|
||||
def _make_buffer(self,
|
||||
*size: Union[int, torch.SymInt],
|
||||
dtype: torch.dtype,
|
||||
numpy: bool = True) -> CpuGpuBuffer:
|
||||
# Bfloat16 torch tensors cannot be directly cast to a numpy array, so
|
||||
# if a bfloat16 buffer is needed without a corresponding numpy array,
|
||||
# don't bother instantiating the numpy array.
|
||||
return CpuGpuBuffer(*size,
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory)
|
||||
pin_memory=self.pin_memory,
|
||||
with_numpy=numpy)
|
||||
|
||||
def _init_model_kwargs(self, num_tokens: int):
|
||||
model_kwargs = dict[str, Any]()
|
||||
@ -1645,11 +1655,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
)
|
||||
|
||||
# TODO(woosuk): Avoid the copy. Optimize.
|
||||
self.inputs_embeds[:num_scheduled_tokens].copy_(
|
||||
self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(
|
||||
inputs_embeds_scheduled)
|
||||
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds[:num_input_tokens]
|
||||
inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
|
||||
model_kwargs = {
|
||||
**self._init_model_kwargs(num_scheduled_tokens),
|
||||
**self._extract_mm_kwargs(scheduler_output),
|
||||
@ -2484,7 +2494,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_scheduled_tokens, remove_lora):
|
||||
if self.supports_mm_inputs:
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds[:num_tokens]
|
||||
inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
|
||||
model_kwargs = {
|
||||
**self._init_model_kwargs(num_tokens),
|
||||
**self._dummy_mm_kwargs(num_reqs),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user