[TPU] Deprecate `xm.mark_step` in favor of `torch_xla.sync` (#25254)

Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
commit 4cf71cc88a (parent a66d131381)
Author: Nicolò Lucchesi
Date:   2025-09-22 12:12:57 +02:00 (committed by GitHub)
5 changed files with 31 additions and 29 deletions
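
The mechanical change applied across the five files is sketched below (a minimal before/after; it assumes a torch_xla release that exposes the top-level `torch_xla.sync` entry point, and a few call sites in the diff use the bare `torch_xla.sync()` instead of passing `wait=False`):

    # Before (deprecated spelling):
    import torch_xla.core.xla_model as xm
    xm.mark_step()              # cut the traced graph and dispatch it to the device

    # After (preferred spelling):
    import torch_xla
    torch_xla.sync(wait=False)  # same cut-and-dispatch, without blocking the host
    xm.wait_device_ops()        # unchanged where a hard barrier is still required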

View File

@@ -6,6 +6,7 @@ Run `pytest tests/kernels/moe/test_moe_pallas.py`.
 """
 import pytest
 import torch
+import torch_xla

 # yapf conflicts with isort for this block
 # yapf: disable
@@ -77,7 +78,7 @@ def test_pallas_moe(
         expert_map=e_map,
         renormalize=False,
     )
-    xm.mark_step()
+    torch_xla.sync(wait=False)

     # Compare outputs
     torch.testing.assert_close(

View File

@@ -4,6 +4,7 @@ import math

 import pytest
 import torch
+import torch_xla

 from vllm.platforms import current_platform
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
@@ -63,7 +64,7 @@ def test_topp_result_sums_past_p():
     probs.masked_fill_(logits_masked.isinf(), 0)
     masked_prob_sum = probs.sum(dim=-1)

-    xm.mark_step()
+    torch_xla.sync()

     # Perform assertion on CPU.
     assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))
@@ -82,7 +83,7 @@ def test_topp_basic():
         k=torch.tensor([3, 3]),
         p=torch.tensor([0.79, 0.79]))

-    xm.mark_step()
+    torch_xla.sync()

     # Expect the smallest elements to be dropped.
     expected_result = logits.clone().cpu()
@@ -104,7 +105,7 @@ def test_topp_select_all():
         k=torch.tensor([3, 3]),
         p=torch.tensor([1.0, 1.0]))

-    xm.mark_step()
+    torch_xla.sync()

     assert torch.allclose(logits.cpu(), result.cpu())
@@ -122,7 +123,7 @@ def test_topp_with_ties():
         k=torch.tensor([4]),
         p=torch.tensor([0.2]))

-    xm.mark_step()
+    torch_xla.sync()

     # All tie values are included in the top-p set. Tie breaking is left
     # to be done during final sampling (all tie tokens have equal
@@ -146,7 +147,7 @@ def test_both_topk_topp():
         k=torch.tensor([1, 3]),
         p=torch.tensor([0.79, 0.79]))

-    xm.mark_step()
+    torch_xla.sync()

     # Since for the first batch k=1, expect only the largest element gets
     # selected.

View File

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union

 import torch
 import torch.nn.functional as F
-import torch_xla.core.xla_model as xm
+import torch_xla

 from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
 from vllm.lora.punica_wrapper.utils import convert_mapping
@@ -323,7 +323,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         extra_vocab_size: int,
     ):
         # Make sure we don't accidentally collect outside operations
-        xm.mark_step()
+        torch_xla.sync()

         # Pad the prompt mapping to avoid running into recompiles on the TPU
         # TODO: Should this happen inside mapping internally? If so how can we

View File

@@ -211,16 +211,15 @@ class DefaultModelLoader(BaseModelLoader):
         from vllm.platforms.tpu import USE_TPU_COMMONS
         if not USE_TPU_COMMONS:
-            # In PyTorch XLA, we should call `xm.mark_step`
+            # In PyTorch XLA, we should call `torch_xla.sync`
             # frequently so that not too many ops are accumulated
-            # in the XLA program. import torch_xla.core.xla_model
-            # as xm
-            import torch_xla.core.xla_model as xm
+            # in the XLA program.
+            import torch_xla

             def _xla_weights_iterator(iterator: Generator):
                 for weights in iterator:
                     yield weights
-                    xm.mark_step()
+                    torch_xla.sync(wait=False)

             weights_iterator = _xla_weights_iterator(weights_iterator)
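
As a side note on the weight-loading hunk above: under XLA's lazy tensors, ops keep accumulating into one traced graph until a sync point, so syncing once per yielded weight keeps each compiled program small. A standalone restatement of that wrapper pattern, with an illustrative name (`_sync_per_item` is not a vLLM helper):

    from collections.abc import Iterable, Iterator

    import torch_xla

    def _sync_per_item(items: Iterable) -> Iterator:
        # Flush the traced XLA graph after every item so pending ops never
        # pile up into one huge program (mirrors _xla_weights_iterator above).
        for item in items:
            yield item
            torch_xla.sync(wait=False)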

View File

@@ -10,6 +10,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 # TPU XLA related
+import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.spmd as xs
 import torch_xla.runtime as xr
@@ -846,10 +847,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # 2. A list or tuple (length: num_items) of tensors, each of shape
             # (feature_size, hidden_size) in case the feature size is dynamic
             # depending on the input multimodal items.
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             curr_group_outputs = self.model.get_multimodal_embeddings(
                 **mm_kwargs_group)
-            xm.mark_step()
+            torch_xla.sync(wait=False)

             sanity_check_mm_encoder_outputs(
                 curr_group_outputs,
@@ -952,7 +953,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             mm_embeds = self._gather_mm_embeddings(scheduler_output)
         else:
             mm_embeds = []
-        xm.mark_step()
+        torch_xla.sync(wait=False)
         # Prepare inputs, the requests might be split into multiple
         # executions, combine the result of each execution.
         start_index = 0
@@ -969,7 +970,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            end_index = self._prepare_inputs(scheduler_output, start_index)
            input_ids, inputs_embeds = self._get_model_inputs(
                self.input_ids, mm_embeds)
-           xm.mark_step()
+           torch_xla.sync(wait=False)
            # Run the decoder
            with set_forward_context(
                    attn_metadata,
@@ -1183,7 +1184,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         # Sync all pending XLA execution during model initialization and weight
         # loading.
-        xm.mark_step()
+        torch_xla.sync(wait=False)
         xm.wait_device_ops()
         if not hasattr(self, "model"):
             self.model = model
@@ -1267,10 +1268,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

     def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping,
                           lora_requests) -> None:
-        xm.mark_step()  # Captures input updates
+        torch_xla.sync(wait=False)  # Captures input updates
         super()._set_active_loras(prompt_lora_mapping, token_lora_mapping,
                                   lora_requests)
-        xm.mark_step()  # Captures metadata updates
+        torch_xla.sync(wait=False)  # Captures metadata updates

     def _precompile_mm_encoder(self) -> None:
         if not self.supports_mm_inputs:
@@ -1297,10 +1298,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 num_items,
             )
             # Run multimodal encoder.
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             mm_embeds = self.model.get_multimodal_embeddings(
                 **batched_dummy_mm_inputs)
-            xm.mark_step()
+            torch_xla.sync(wait=False)

             num_patches = mm_embeds[0].shape[0]
             items_size = num_patches * num_items
@@ -1325,7 +1326,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 a, b = self._get_model_inputs(placeholders_ids,
                                               [mm_embeds])
                 assert a is None
-                xm.mark_step()
+                torch_xla.sync(wait=False)

             # Pre-compile `get_input_embeddings` when mm_embeddings are not
             # present. Chunk is only made of text, no mm_placeholders.
@@ -1336,7 +1337,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             placeholders_ids = placeholders_ids.to(self.device)
             a, b = self._get_model_inputs(placeholders_ids, [])
             assert a is None
-            xm.mark_step()
+            torch_xla.sync(wait=False)

         xm.wait_device_ops()
         end = time.perf_counter()
@@ -1532,11 +1533,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # Isolate encoder graph from post-processing to minimize
             # impact of recompilation until it's fixed.
             start = time.perf_counter()
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             dummy_encoder_outputs = \
                 self.model.get_multimodal_embeddings(
                     **batched_dummy_mm_inputs)
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             xm.wait_device_ops()
             end = time.perf_counter()
             logger.info(
@@ -1559,7 +1560,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self._dummy_run(num_tokens, self.num_reqs_most_model_len,
                             self.num_blocks_per_most_len_req)
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             xm.wait_device_ops()

         self.encoder_cache.clear()
         gc.collect()
@@ -1927,11 +1928,11 @@ def replace_set_lora(model):
         # to a tensor doesn't seem to work anymore. This might be fixed with a
         # later release of torch_xla.
         self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias)
-        xm.mark_step()
+        torch_xla.sync(wait=False)

     def _tpu_reset_lora(self, index: int):
         self._original_reset_lora(index)
-        xm.mark_step()
+        torch_xla.sync(wait=False)

     for _, module in model.named_modules():
         if isinstance(module, BaseLayerWithLoRA):