[Misc] Use a platform-independent interface to obtain device attributes (#17100)
commit bdb2cddafc
parent ebb3930d28
@@ -293,7 +293,8 @@ class HfRunner:
     def get_default_device(self):
         from vllm.platforms import current_platform
 
-        return ("cpu" if current_platform.is_cpu() else "cuda")
+        return ("cpu"
+                if current_platform.is_cpu() else current_platform.device_type)
 
     def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
         if x is None or isinstance(x, (bool, )):
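The hunk above makes HfRunner's default device follow the detected platform: on CUDA builds current_platform.device_type is "cuda", so behavior is unchanged, while other backends now get their own device string instead of a hard-coded "cuda". A minimal sketch of the same resolution pattern, assuming a CPU or CUDA build of vLLM (device strings for other backends may not be accepted by torch directly):

    # Resolve a default device the way HfRunner.get_default_device now does.
    # current_platform is detected at import time; device_type is e.g.
    # "cuda" on NVIDIA builds and "cpu" on CPU-only builds.
    import torch

    from vllm.platforms import current_platform

    device = ("cpu"
              if current_platform.is_cpu() else current_platform.device_type)
    t = torch.zeros(4, device=device)  # tensor lands on the platform device
    print(device, t.device)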
@@ -6,6 +6,7 @@ import numpy as np
 import pytest
 import torch
 
+from vllm.platforms import current_platform
 from vllm.utils import make_tensor_with_pad
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.sampler import Sampler
@@ -13,7 +14,8 @@ from vllm.v1.sample.sampler import Sampler
 VOCAB_SIZE = 1024
 NUM_OUTPUT_TOKENS = 20
 CUDA_DEVICES = [
-    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+    f"{current_platform.device_type}:{i}"
+    for i in range(1 if current_platform.device_count() == 1 else 2)
 ]
 MAX_NUM_PROMPT_TOKENS = 64
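These two hunks swap the CUDA-only torch.cuda.device_count() for the platform interface when building the test device list. For illustration, a sketch of what the comprehension evaluates to, assuming current_platform.device_count() reports the number of visible devices the way torch.cuda.device_count() does on CUDA:

    # On a single-GPU CUDA host this evaluates to ["cuda:0"]; with two or
    # more visible GPUs it evaluates to ["cuda:0", "cuda:1"]. On another
    # backend the prefix follows current_platform.device_type instead.
    from vllm.platforms import current_platform

    DEVICES = [
        f"{current_platform.device_type}:{i}"
        for i in range(1 if current_platform.device_count() == 1 else 2)
    ]
    print(DEVICES)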
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
                                                 SamplerOutput,
                                                 SamplingMetadata, get_logprobs,
                                                 get_pythonized_sample_results)
+from vllm.platforms import current_platform
 from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
                            Logprob, SequenceGroupMetadata, SequenceOutput)
 from vllm.utils import PyObjectCache, async_tensor_h2d, current_stream
@@ -158,8 +159,8 @@ class StatefulModelInput(BroadcastableModelInput):
     is_first_multi_step: bool = False
     base_output_proc_callback: Optional[Callable] = None
     # ping-pong data structures for multi-step to wait on the previous step
-    step_cuda_events: List[torch.cuda.Event] = field(
-        default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2)
+    step_cuda_events: List[current_platform.Event] = field(
+        default_factory=lambda: [current_platform.Event(blocking=True)] * 2)
     num_seqs: int = -1
     num_queries: int = -1
     num_single_step_prefills: int = 0
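The last hunk routes the multi-step ping-pong events through current_platform.Event rather than torch.cuda.Event directly; presumably the attribute resolves to the backend's native event class (torch.cuda.Event on CUDA). A sketch of the record/wait pattern such events support, assuming torch.cuda.Event semantics; it constructs two distinct events rather than the shared-reference [event] * 2 of the dataclass default above:

    # Ping-pong events: each slot is reused every other step, so a step
    # waits on the event recorded two steps earlier before reusing the slot.
    from vllm.platforms import current_platform

    step_events = [current_platform.Event(blocking=True) for _ in range(2)]

    for step in range(4):
        slot = step % 2
        step_events[slot].synchronize()  # no-op until the event is recorded
        # ... enqueue this step's GPU work on the current stream ...
        step_events[slot].record()       # mark this step's work as submitted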