[Bugfix] Fix auto dtype casting for BatchFeature (#19316)

Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Author: Isotr0py
Date:   2025-06-14 23:13:08 +08:00 (committed by GitHub)
Commit: 2db9044ab6 (parent: 6fa718a460)
7 changed files with 85 additions and 57 deletions

View File

@@ -15,6 +15,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.inputs import PromptType
 from vllm.platforms import current_platform
 from vllm.sampling_params import RequestOutputKind
+from vllm.utils import set_default_torch_num_threads
 from vllm.v1.engine.async_llm import AsyncLLM
 from vllm.v1.metrics.loggers import LoggingStatLogger
@@ -107,6 +108,7 @@ async def test_load(
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
-        engine = AsyncLLM.from_engine_args(engine_args)
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)
@@ -154,6 +156,7 @@ async def test_abort(
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
-        engine = AsyncLLM.from_engine_args(engine_args)
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)
@@ -226,6 +229,7 @@ async def test_finished_flag(
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
-        engine = AsyncLLM.from_engine_args(engine_args)
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)
@@ -260,6 +264,7 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
-        engine = AsyncLLM.from_engine_args(engine_args)
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)
@@ -322,6 +327,7 @@ async def test_customize_loggers(monkeypatch):
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
-        engine = AsyncLLM.from_engine_args(
-            TEXT_ENGINE_ARGS,
-            stat_loggers=[MockLoggingStatLogger],
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(
+                TEXT_ENGINE_ARGS,
+                stat_loggers=[MockLoggingStatLogger],
@@ -340,6 +346,7 @@ async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
-        engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
         after.callback(engine.shutdown)

View File

@@ -12,6 +12,7 @@ from transformers import AutoTokenizer
 from vllm import SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
+from vllm.utils import set_default_torch_num_threads
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core import EngineCore
 from vllm.v1.executor.abstract import Executor, UniProcExecutor
@@ -56,6 +57,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
         vllm_config = engine_args.create_engine_config()
         executor_class = Executor.get_class(vllm_config)
-        engine_core = EngineCore(vllm_config=vllm_config,
-                                 executor_class=executor_class,
-                                 log_stats=True)
+        with set_default_torch_num_threads(1):
+            engine_core = EngineCore(vllm_config=vllm_config,
+                                     executor_class=executor_class,
+                                     log_stats=True)
@@ -190,6 +192,7 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
         vllm_config = engine_args.create_engine_config()
         executor_class = Executor.get_class(vllm_config)
-        engine_core = EngineCore(vllm_config=vllm_config,
-                                 executor_class=executor_class,
-                                 log_stats=True)
+        with set_default_torch_num_threads(1):
+            engine_core = EngineCore(vllm_config=vllm_config,
+                                     executor_class=executor_class,
+                                     log_stats=True)
@@ -286,6 +289,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
             enforce_eager=True,
         )
         vllm_config = engine_args.create_engine_config()
-        engine_core = EngineCore(vllm_config=vllm_config,
-                                 log_stats=False,
-                                 executor_class=DummyExecutor)
+        with set_default_torch_num_threads(1):
+            engine_core = EngineCore(vllm_config=vllm_config,
+                                     log_stats=False,
+                                     executor_class=DummyExecutor)

View File

@@ -19,6 +19,7 @@ from vllm.distributed.kv_events import (BlockStored, KVEventBatch,
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import set_default_torch_num_threads
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core import EngineCore
 from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
@@ -138,6 +139,8 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
         vllm_config = engine_args.create_engine_config(
             UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
-        client = EngineCoreClient.make_client(
-            multiprocess_mode=multiprocessing_mode,
-            asyncio_mode=False,
+        with set_default_torch_num_threads(1):
+            client = EngineCoreClient.make_client(
+                multiprocess_mode=multiprocessing_mode,
+                asyncio_mode=False,
@@ -223,6 +226,8 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
         vllm_config = engine_args.create_engine_config(
             usage_context=UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
-        client = EngineCoreClient.make_client(
-            multiprocess_mode=True,
-            asyncio_mode=True,
+        with set_default_torch_num_threads(1):
+            client = EngineCoreClient.make_client(
+                multiprocess_mode=True,
+                asyncio_mode=True,
@@ -312,6 +317,7 @@ def test_kv_cache_events(
             UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
-        client = EngineCoreClient.make_client(
-            multiprocess_mode=multiprocessing_mode,
-            asyncio_mode=False,
+        with set_default_torch_num_threads(1):
+            client = EngineCoreClient.make_client(
+                multiprocess_mode=multiprocessing_mode,
+                asyncio_mode=False,
@@ -394,6 +400,7 @@ async def test_kv_cache_events_dp(
             UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
-        client = EngineCoreClient.make_client(
-            multiprocess_mode=multiprocessing_mode,
-            asyncio_mode=True,
+        with set_default_torch_num_threads(1):
+            client = EngineCoreClient.make_client(
+                multiprocess_mode=multiprocessing_mode,
+                asyncio_mode=True,

View File

@@ -168,10 +168,12 @@ class InputProcessingContext(InputContext):
         try:
             output = hf_processor(**data, **merged_kwargs, return_tensors="pt")
             # this emulates output.to(dtype=self.model_config.dtype)
-            cast_output = json_map_leaves(maybe_cast_dtype, output)
             if isinstance(output, BatchFeature):
+                cast_output = json_map_leaves(maybe_cast_dtype, output.data)
                 return BatchFeature(cast_output)
+            cast_output = json_map_leaves(maybe_cast_dtype, output)
             logger.warning_once(
                 f"{type(hf_processor).__name__} did not return `BatchFeature`. "
                 "Make sure to match the behaviour of `ProcessorMixin` when "

View File

@@ -965,9 +965,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         grid_thw_list = grid_thw.tolist()
         if image_input["type"] == "image_embeds":
-            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+            image_embeds = image_input["image_embeds"]
         else:
-            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+            pixel_values = image_input["pixel_values"]
             image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
         # Split concatenated embeddings for each image item.
@@ -985,10 +985,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         grid_thw_list = grid_thw.tolist()
         if video_input["type"] == "video_embeds":
-            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+            video_embeds = video_input["video_embeds"]
         else:
-            pixel_values_videos = video_input["pixel_values_videos"].type(
-                self.visual.dtype)
+            pixel_values_videos = video_input["pixel_values_videos"]
             video_embeds = self.visual(pixel_values_videos,
                                        grid_thw=grid_thw_list)

View File

@@ -1208,9 +1208,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         assert grid_thw.ndim == 2
         if image_input["type"] == "image_embeds":
-            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+            image_embeds = image_input["image_embeds"]
         else:
-            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+            pixel_values = image_input["pixel_values"]
             image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
         # Split concatenated embeddings for each image item.
@@ -1226,10 +1226,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         assert grid_thw.ndim == 2
         if video_input["type"] == "video_embeds":
-            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+            video_embeds = video_input["video_embeds"]
         else:
-            pixel_values_videos = video_input["pixel_values_videos"].type(
-                self.visual.dtype)
+            pixel_values_videos = video_input["pixel_values_videos"]
             video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
         # Split concatenated embeddings for each video item.
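Editor's note on the two Qwen hunks above: once the processor output is cast to the model dtype in the input registry, the in-model `.type(self.visual.dtype)` calls become redundant and are dropped. A toy illustration of the equivalence follows; `torch.bfloat16` and the tensor shape are assumptions for the example only.

import torch

visual_dtype = torch.bfloat16               # assumed dtype of self.visual
pixel_values = torch.rand(4, 1176)          # toy stand-in for processor output

old_path = pixel_values.type(visual_dtype)  # cast done inside the model (before)
new_path = pixel_values.to(visual_dtype)    # cast already done upstream (after)
assert old_path.dtype == new_path.dtype == visual_dtype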

View File

@@ -190,6 +190,16 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
     torch.int64: np.int64,
 }
+
+
+@contextlib.contextmanager
+def set_default_torch_num_threads(num_threads: int):
+    """Sets the default number of threads for PyTorch to the given value."""
+    old_num_threads = torch.get_num_threads()
+    torch.set_num_threads(num_threads)
+    yield
+    torch.set_num_threads(old_num_threads)
+
 P = ParamSpec('P')
 T = TypeVar("T")
 U = TypeVar("U")
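Editor's note: a brief usage sketch of the new context manager, matching how the test hunks above wrap engine construction; limiting PyTorch to one thread keeps the CPU-bound setup from oversubscribing cores, and the previous value is restored on exit. The printed values are examples, not guarantees.

import torch

from vllm.utils import set_default_torch_num_threads

print(torch.get_num_threads())        # global default, e.g. the core count
with set_default_torch_num_threads(1):
    print(torch.get_num_threads())    # 1 inside the block
print(torch.get_num_threads())        # previous value restored afterwards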