[TPU] Deprecate xm.mark_step in favor of `torch_xla.sync` (#25254)
Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>

parent a66d131381
commit 4cf71cc88a
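Every call site in the diff below makes the same swap: `xm.mark_step()` becomes `torch_xla.sync()`, with `wait=False` passed explicitly on the hot paths in the model runner and the weight loader, while `xm.wait_device_ops()` stays where a blocking barrier is still required. A minimal sketch of the old and new spellings, assuming the top-level `torch_xla.sync()` API from recent torch_xla releases (the helper functions are illustrative, not part of the diff):

import torch
import torch_xla                         # new top-level entry point
import torch_xla.core.xla_model as xm    # old module, still used for wait_device_ops()


def old_style(x: torch.Tensor) -> torch.Tensor:
    y = x * 2
    xm.mark_step()              # deprecated: cut the lazy graph here
    return y


def new_style(x: torch.Tensor) -> torch.Tensor:
    y = x * 2
    torch_xla.sync(wait=False)  # same cut, non-blocking as used in the model runner
    return y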
@@ -6,6 +6,7 @@ Run `pytest tests/kernels/moe/test_moe_pallas.py`.
 """
 import pytest
 import torch
+import torch_xla

 # yapf conflicts with isort for this block
 # yapf: disable
@@ -77,7 +78,7 @@ def test_pallas_moe(
         expert_map=e_map,
         renormalize=False,
     )
-    xm.mark_step()
+    torch_xla.sync(wait=False)

     # Compare outputs
     torch.testing.assert_close(
@@ -4,6 +4,7 @@ import math

 import pytest
 import torch
+import torch_xla

 from vllm.platforms import current_platform
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
@@ -63,7 +64,7 @@ def test_topp_result_sums_past_p():
     probs.masked_fill_(logits_masked.isinf(), 0)
     masked_prob_sum = probs.sum(dim=-1)

-    xm.mark_step()
+    torch_xla.sync()

     # Perform assertion on CPU.
     assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))
@@ -82,7 +83,7 @@ def test_topp_basic():
         k=torch.tensor([3, 3]),
         p=torch.tensor([0.79, 0.79]))

-    xm.mark_step()
+    torch_xla.sync()

     # Expect the smallest elements to be dropped.
     expected_result = logits.clone().cpu()
@@ -104,7 +105,7 @@ def test_topp_select_all():
         k=torch.tensor([3, 3]),
         p=torch.tensor([1.0, 1.0]))

-    xm.mark_step()
+    torch_xla.sync()

     assert torch.allclose(logits.cpu(), result.cpu())

@@ -122,7 +123,7 @@ def test_topp_with_ties():
         k=torch.tensor([4]),
         p=torch.tensor([0.2]))

-    xm.mark_step()
+    torch_xla.sync()

     # All tie values are included in the top-p set. Tie breaking is left
     # to be done during final sampling (all tie tokens have equal
@@ -146,7 +147,7 @@ def test_both_topk_topp():
         k=torch.tensor([1, 3]),
         p=torch.tensor([0.79, 0.79]))

-    xm.mark_step()
+    torch_xla.sync()

     # Since for the first batch k=1, expect only the largest element gets
     # selected.
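The sampler test hunks above all follow one shape: run the op on the XLA device, call `torch_xla.sync()` to materialize the lazily traced graph, then move results to CPU for assertions. A hedged sketch of that pattern (the shapes and the softmax stand-in are illustrative, not taken from the tests):

import torch
import torch_xla
import torch_xla.core.xla_model as xm


def check_on_cpu() -> None:
    device = xm.xla_device()
    logits = torch.randn(2, 16, device=device)

    # Any lazy XLA computation; softmax is just a stand-in here.
    probs = torch.softmax(logits, dim=-1)

    # Materialize the traced graph before pulling values back to the host.
    torch_xla.sync()

    # Assertions happen on CPU, mirroring the tests above.
    assert torch.allclose(probs.sum(dim=-1).cpu(), torch.ones(2), atol=1e-5)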
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union

 import torch
 import torch.nn.functional as F
-import torch_xla.core.xla_model as xm
+import torch_xla

 from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
 from vllm.lora.punica_wrapper.utils import convert_mapping
@@ -323,7 +323,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         extra_vocab_size: int,
     ):
         # Make sure we don't accidentally collect outside operations
-        xm.mark_step()
+        torch_xla.sync()

         # Pad the prompt mapping to avoid running into recompiles on the TPU
         # TODO: Should this happen inside mapping internally? If so how can we
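`PunicaWrapperTPU` issues the sync at the top of its metadata-update path so that whatever the caller traced beforehand is flushed rather than folded into the LoRA mapping graph. A rough sketch of that isolation idea, with hypothetical class and method names:

import torch_xla


class MappingSketch:
    """Illustrative only; not the real PunicaWrapperTPU interface."""

    def update_metadata(self, *args, **kwargs) -> None:
        # Flush anything traced by the caller so it is not captured
        # in the graph built for the mapping update below.
        torch_xla.sync()
        self._build_mapping(*args, **kwargs)

    def _build_mapping(self, *args, **kwargs) -> None:
        ...  # hypothetical: compute prompt/token LoRA mappings here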
@@ -211,16 +211,15 @@ class DefaultModelLoader(BaseModelLoader):
         from vllm.platforms.tpu import USE_TPU_COMMONS

         if not USE_TPU_COMMONS:
-            # In PyTorch XLA, we should call `xm.mark_step`
+            # In PyTorch XLA, we should call `torch_xla.sync`
             # frequently so that not too many ops are accumulated
-            # in the XLA program. import torch_xla.core.xla_model
-            # as xm
-            import torch_xla.core.xla_model as xm
+            # in the XLA program.
+            import torch_xla

             def _xla_weights_iterator(iterator: Generator):
                 for weights in iterator:
                     yield weights
-                    xm.mark_step()
+                    torch_xla.sync(wait=False)

             weights_iterator = _xla_weights_iterator(weights_iterator)

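The loader wraps the weights iterator so the lazy graph is cut after every yielded tensor; otherwise all weight-loading ops would pile up into one oversized XLA program. A self-contained sketch of that wrapper (the fake weights generator is a placeholder added for illustration):

from collections.abc import Generator

import torch
import torch_xla


def _xla_weights_iterator(iterator: Generator):
    # Yield each weight, then cut the lazy graph so per-weight ops are
    # compiled and executed incrementally instead of all at once.
    for weights in iterator:
        yield weights
        torch_xla.sync(wait=False)


def fake_weights() -> Generator:
    # Placeholder tensors purely for illustration.
    for i in range(3):
        yield f"layer.{i}.weight", torch.zeros(4, 4)


for name, tensor in _xla_weights_iterator(fake_weights()):
    print(name, tuple(tensor.shape))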
@@ -10,6 +10,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 # TPU XLA related
+import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.spmd as xs
 import torch_xla.runtime as xr
@@ -846,10 +847,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # 2. A list or tuple (length: num_items) of tensors, each of shape
             # (feature_size, hidden_size) in case the feature size is dynamic
             # depending on the input multimodal items.
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             curr_group_outputs = self.model.get_multimodal_embeddings(
                 **mm_kwargs_group)
-            xm.mark_step()
+            torch_xla.sync(wait=False)

             sanity_check_mm_encoder_outputs(
                 curr_group_outputs,
@@ -952,7 +953,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             mm_embeds = self._gather_mm_embeddings(scheduler_output)
         else:
             mm_embeds = []
-        xm.mark_step()
+        torch_xla.sync(wait=False)
         # Prepare inputs, the requests might be split into multiple
         # executions, combine the result of each execution.
         start_index = 0
@@ -969,7 +970,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             end_index = self._prepare_inputs(scheduler_output, start_index)
             input_ids, inputs_embeds = self._get_model_inputs(
                 self.input_ids, mm_embeds)
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             # Run the decoder
             with set_forward_context(
                     attn_metadata,
@@ -1183,7 +1184,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         # Sync all pending XLA execution during model initialization and weight
         # loading.
-        xm.mark_step()
+        torch_xla.sync(wait=False)
         xm.wait_device_ops()
         if not hasattr(self, "model"):
             self.model = model
@@ -1267,10 +1268,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

     def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping,
                           lora_requests) -> None:
-        xm.mark_step()  # Captures input updates
+        torch_xla.sync(wait=False)  # Captures input updates
         super()._set_active_loras(prompt_lora_mapping, token_lora_mapping,
                                   lora_requests)
-        xm.mark_step()  # Captures metadata updates
+        torch_xla.sync(wait=False)  # Captures metadata updates

     def _precompile_mm_encoder(self) -> None:
         if not self.supports_mm_inputs:
@@ -1297,10 +1298,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             num_items,
         )
         # Run multimodal encoder.
-        xm.mark_step()
+        torch_xla.sync(wait=False)
         mm_embeds = self.model.get_multimodal_embeddings(
             **batched_dummy_mm_inputs)
-        xm.mark_step()
+        torch_xla.sync(wait=False)
         num_patches = mm_embeds[0].shape[0]
         items_size = num_patches * num_items

@@ -1325,7 +1326,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             a, b = self._get_model_inputs(placeholders_ids,
                                           [mm_embeds])
             assert a is None
-            xm.mark_step()
+            torch_xla.sync(wait=False)

             # Pre-compile `get_input_embeddings` when mm_embeddings are not
             # present. Chunk is only made of text, no mm_placeholders.
@@ -1336,7 +1337,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             placeholders_ids = placeholders_ids.to(self.device)
             a, b = self._get_model_inputs(placeholders_ids, [])
             assert a is None
-            xm.mark_step()
+            torch_xla.sync(wait=False)

         xm.wait_device_ops()
         end = time.perf_counter()
@@ -1532,11 +1533,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # Isolate encoder graph from post-processing to minimize
             # impact of recompilation until it's fixed.
             start = time.perf_counter()
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             dummy_encoder_outputs = \
                 self.model.get_multimodal_embeddings(
                     **batched_dummy_mm_inputs)
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             xm.wait_device_ops()
             end = time.perf_counter()
             logger.info(
@@ -1559,7 +1560,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self._dummy_run(num_tokens, self.num_reqs_most_model_len,
                             self.num_blocks_per_most_len_req)

-        xm.mark_step()
+        torch_xla.sync(wait=False)
         xm.wait_device_ops()
         self.encoder_cache.clear()
         gc.collect()
@@ -1927,11 +1928,11 @@ def replace_set_lora(model):
         # to a tensor doesn't seem to work anymore. This might be fixed with a
         # later release of torch_xla.
         self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias)
-        xm.mark_step()
+        torch_xla.sync(wait=False)

     def _tpu_reset_lora(self, index: int):
         self._original_reset_lora(index)
-        xm.mark_step()
+        torch_xla.sync(wait=False)

     for _, module in model.named_modules():
         if isinstance(module, BaseLayerWithLoRA):
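Throughout the model runner's warm-up paths, the non-blocking `torch_xla.sync(wait=False)` is paired with `xm.wait_device_ops()`, which blocks until the submitted XLA programs finish; that keeps compilation cost inside the timed precompile phase rather than on the first real request. A hedged sketch of that pairing (the dummy workload is a placeholder):

import time

import torch
import torch_xla
import torch_xla.core.xla_model as xm


def warm_up() -> None:
    device = xm.xla_device()
    start = time.perf_counter()

    # Placeholder for a dummy run that traces the graphs we want compiled.
    x = torch.randn(8, 8, device=device)
    y = x @ x

    # Dispatch the traced graph without blocking ...
    torch_xla.sync(wait=False)
    # ... then wait for the device so compilation finishes inside the
    # warm-up window.
    xm.wait_device_ops()

    print(f"warm-up for {tuple(y.shape)} took {time.perf_counter() - start:.3f}s")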