mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-30 14:11:48 +08:00
[Misc] Move CpuGpuBuffer to vllm/v1/utils.py (#23728)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
6578e87365
commit
04ff1e43fb
@ -96,6 +96,35 @@ class ConstantList(Generic[T], Sequence):
|
||||
return f"ConstantList({self._x})"
|
||||
|
||||
|
||||
class CpuGpuBuffer:
    """Paired host/device tensors of the same shape and dtype.

    The host tensor (``cpu``) is zero-initialized; the device tensor (``gpu``)
    is created from it with ``Tensor.to(device)``. ``copy_to_gpu`` and
    ``copy_to_cpu`` move data between the two with ``non_blocking=True``.
    """

    def __init__(
        self,
        *args,
        dtype: torch.dtype,
        device: torch.device,
        pin_memory: bool,
        with_numpy: bool = True,
    ) -> None:
        """Allocate the host tensor and its device counterpart.

        Args:
            *args: Size arguments forwarded unchanged to ``torch.zeros``.
            dtype: Element type of both tensors.
            device: Target device for the ``gpu`` tensor.
            pin_memory: Forwarded to ``torch.zeros`` for the host tensor.
            with_numpy: When True (the default, matching the previous
                behavior) expose ``self.np``, a NumPy view of the host
                tensor. Pass False for dtypes NumPy cannot represent
                (e.g. ``torch.bfloat16``), for which ``Tensor.numpy()``
                raises; ``self.np`` is then not set at all.
        """
        # Host-side tensor, zero-filled.
        self.cpu = torch.zeros(*args,
                               dtype=dtype,
                               device="cpu",
                               pin_memory=pin_memory)
        if with_numpy:
            # NumPy view over the host tensor; writes through either one are
            # visible through the other.
            self.np = self.cpu.numpy()
        # Device-side counterpart produced by Tensor.to(device).
        self.gpu = self.cpu.to(device)

    def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
        """Copy host data into the device tensor and return the destination.

        Args:
            n: If given, only the slice ``[:n]`` is copied; otherwise the
                whole tensor is copied.

        Returns:
            The device tensor (or its ``[:n]`` slice) that was written to.

        NOTE: The copy is issued with ``non_blocking=True``.
        """
        if n is None:
            return self.gpu.copy_(self.cpu, non_blocking=True)
        return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)

    def copy_to_cpu(self, n: Optional[int] = None) -> torch.Tensor:
        """Copy device data into the host tensor and return the destination.

        NOTE: Because this method is non-blocking, explicit synchronization
        is needed to ensure the data is copied to CPU.

        Args:
            n: If given, only the slice ``[:n]`` is copied; otherwise the
                whole tensor is copied.

        Returns:
            The host tensor (or its ``[:n]`` slice) that was written to.
        """
        if n is None:
            return self.cpu.copy_(self.gpu, non_blocking=True)
        return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True)
|
||||
|
||||
|
||||
def get_engine_client_zmq_addr(local_only: bool,
|
||||
host: str,
|
||||
port: int = 0) -> str:
|
||||
|
||||
@ -10,8 +10,8 @@ from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
from vllm.v1.worker.utils import CpuGpuBuffer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
@ -78,14 +78,14 @@ from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
from vllm.v1.spec_decode.medusa import MedusaProposer
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import (
|
||||
KVConnectorModelRunnerMixin, KVConnectorOutput)
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
|
||||
from .utils import (AttentionGroup, CpuGpuBuffer, MultiModalBudget,
|
||||
bind_kv_cache, gather_mm_placeholders,
|
||||
initialize_kv_cache_for_kv_sharing,
|
||||
from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache,
|
||||
gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
|
||||
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@ -303,32 +303,3 @@ def bind_kv_cache(
|
||||
for layer_name, kv_cache in kv_caches.items():
|
||||
# NOTE: Use list because of v0 PP virtual engine.
|
||||
forward_context[layer_name].kv_cache = [kv_cache]
|
||||
|
||||
|
||||
class CpuGpuBuffer:
    """Paired host/device tensors of the same shape and dtype.

    ``cpu`` is a zero-initialized host tensor and ``gpu`` is created from it
    with ``Tensor.to(device)``. The two copy methods transfer data between
    them using ``non_blocking=True`` and return the tensor written to.
    """

    def __init__(
        self,
        *args,
        dtype: torch.dtype,
        device: torch.device,
        pin_memory: bool,
        with_numpy: bool = True,
    ) -> None:
        """Create the host tensor and its device counterpart.

        Args:
            *args: Size arguments passed straight through to ``torch.zeros``.
            dtype: Element type shared by both tensors.
            device: Device on which the ``gpu`` tensor lives.
            pin_memory: Passed to ``torch.zeros`` for the host tensor.
            with_numpy: When True (default; previous behavior) also expose
                ``self.np`` as a NumPy view of the host tensor. Set to False
                for dtypes that ``Tensor.numpy()`` rejects, such as
                ``torch.bfloat16``; in that case ``self.np`` is not defined.
        """
        # Zero-filled host tensor.
        self.cpu = torch.zeros(*args,
                               dtype=dtype,
                               device="cpu",
                               pin_memory=pin_memory)
        if with_numpy:
            # NumPy view of the host tensor (no separate copy of the data).
            self.np = self.cpu.numpy()
        # Device tensor obtained via Tensor.to(device).
        self.gpu = self.cpu.to(device)

    # The return annotations below are torch.Tensor (not None): both methods
    # return the result of Tensor.copy_, i.e. the destination tensor.
    def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
        """Copy host data to the device tensor.

        Args:
            n: Optional length; when given only ``[:n]`` is copied.

        Returns:
            The destination device tensor (or its ``[:n]`` slice).
        """
        if n is None:
            return self.gpu.copy_(self.cpu, non_blocking=True)
        return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)

    def copy_to_cpu(self, n: Optional[int] = None) -> torch.Tensor:
        """Copy device data to the host tensor.

        NOTE: Because this method is non-blocking, explicit synchronization
        is needed to ensure the data is copied to CPU.

        Args:
            n: Optional length; when given only ``[:n]`` is copied.

        Returns:
            The destination host tensor (or its ``[:n]`` slice).
        """
        if n is None:
            return self.cpu.copy_(self.gpu, non_blocking=True)
        return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user