mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-05 23:39:12 +08:00
[Misc] Replace cuda hard code with current_platform (#16983)
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
parent
4be2255c81
commit
9c1baa5bc6
@ -1221,8 +1221,9 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
|
|||||||
ray.shutdown()
|
ray.shutdown()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
if not current_platform.is_cpu():
|
empty_cache = current_platform.empty_cache
|
||||||
torch.cuda.empty_cache()
|
if empty_cache is not None:
|
||||||
|
empty_cache()
|
||||||
try:
|
try:
|
||||||
torch._C._host_emptyCache()
|
torch._C._host_emptyCache()
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
|
|||||||
@ -120,7 +120,10 @@ def set_forward_context(attn_metadata: Any,
|
|||||||
# we use synchronous scheduling right now,
|
# we use synchronous scheduling right now,
|
||||||
# adding a sync point here should not affect
|
# adding a sync point here should not affect
|
||||||
# scheduling of the next batch
|
# scheduling of the next batch
|
||||||
torch.cuda.synchronize()
|
from vllm.platforms import current_platform
|
||||||
|
synchronize = current_platform.synchronize
|
||||||
|
if synchronize is not None:
|
||||||
|
synchronize()
|
||||||
now = time.perf_counter()
|
now = time.perf_counter()
|
||||||
# time measurement is in milliseconds
|
# time measurement is in milliseconds
|
||||||
batchsize_forward_time[batchsize].append(
|
batchsize_forward_time[batchsize].append(
|
||||||
|
|||||||
@ -126,12 +126,12 @@ class AsyncMetricsCollector:
|
|||||||
"""Copy rejection/typical-acceptance sampling metrics
|
"""Copy rejection/typical-acceptance sampling metrics
|
||||||
(number of accepted tokens, etc) to CPU asynchronously.
|
(number of accepted tokens, etc) to CPU asynchronously.
|
||||||
|
|
||||||
Returns a CUDA event recording when the copy is complete.
|
Returns a device event recording when the copy is complete.
|
||||||
"""
|
"""
|
||||||
assert self._copy_stream is not None
|
assert self._copy_stream is not None
|
||||||
self._copy_stream.wait_stream(torch.cuda.current_stream())
|
self._copy_stream.wait_stream(current_platform.current_stream())
|
||||||
|
|
||||||
with torch.cuda.stream(self._copy_stream):
|
with current_platform.stream(self._copy_stream):
|
||||||
self._aggregate_num_accepted_tokens.copy_(
|
self._aggregate_num_accepted_tokens.copy_(
|
||||||
self.spec_decode_sampler.num_accepted_tokens,
|
self.spec_decode_sampler.num_accepted_tokens,
|
||||||
non_blocking=True)
|
non_blocking=True)
|
||||||
@ -142,7 +142,7 @@ class AsyncMetricsCollector:
|
|||||||
self._aggregate_num_draft_tokens = (
|
self._aggregate_num_draft_tokens = (
|
||||||
self.spec_decode_sampler.num_draft_tokens)
|
self.spec_decode_sampler.num_draft_tokens)
|
||||||
|
|
||||||
aggregate_metrics_ready = torch.cuda.Event()
|
aggregate_metrics_ready = current_platform.Event()
|
||||||
aggregate_metrics_ready.record(self._copy_stream)
|
aggregate_metrics_ready.record(self._copy_stream)
|
||||||
|
|
||||||
return aggregate_metrics_ready
|
return aggregate_metrics_ready
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user