diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py
index 7dff937c0fd9..3ae629397268 100644
--- a/tests/v1/engine/test_async_llm.py
+++ b/tests/v1/engine/test_async_llm.py
@@ -15,6 +15,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.inputs import PromptType
 from vllm.platforms import current_platform
 from vllm.sampling_params import RequestOutputKind
+from vllm.utils import set_default_torch_num_threads
 from vllm.v1.engine.async_llm import AsyncLLM
 from vllm.v1.metrics.loggers import LoggingStatLogger
 
@@ -107,7 +108,8 @@ async def test_load(
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
-        engine = AsyncLLM.from_engine_args(engine_args)
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)
 
         NUM_REQUESTS = 100
@@ -154,7 +156,8 @@ async def test_abort(
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
-        engine = AsyncLLM.from_engine_args(engine_args)
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)
 
         NUM_REQUESTS = 100
@@ -226,7 +229,8 @@ async def test_finished_flag(
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
-        engine = AsyncLLM.from_engine_args(engine_args)
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)
 
         sampling_params = SamplingParams(
@@ -260,7 +264,8 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
-        engine = AsyncLLM.from_engine_args(engine_args)
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)
 
         NUM_REQUESTS = 100
@@ -322,10 +327,11 @@ async def test_customize_loggers(monkeypatch):
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
-        engine = AsyncLLM.from_engine_args(
-            TEXT_ENGINE_ARGS,
-            stat_loggers=[MockLoggingStatLogger],
-        )
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(
+                TEXT_ENGINE_ARGS,
+                stat_loggers=[MockLoggingStatLogger],
+            )
         after.callback(engine.shutdown)
 
         await engine.do_log_stats()
@@ -340,7 +346,8 @@ async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
-        engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
         after.callback(engine.shutdown)
 
         sampling_params = SamplingParams(max_tokens=100,
diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py
index 1cbbf30371af..fbbfc630d27d 100644
--- a/tests/v1/engine/test_engine_core.py
+++ b/tests/v1/engine/test_engine_core.py
@@ -12,6 +12,7 @@ from transformers import AutoTokenizer
 from vllm import SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
+from vllm.utils import set_default_torch_num_threads
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core import EngineCore
 from vllm.v1.executor.abstract import Executor, UniProcExecutor
@@ -56,9 +57,10 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
         vllm_config = engine_args.create_engine_config()
         executor_class = Executor.get_class(vllm_config)
 
-        engine_core = EngineCore(vllm_config=vllm_config,
-                                 executor_class=executor_class,
-                                 log_stats=True)
+        with set_default_torch_num_threads(1):
+            engine_core = EngineCore(vllm_config=vllm_config,
+                                     executor_class=executor_class,
+                                     log_stats=True)
         """Test basic request lifecycle."""
 
         # First request.
@@ -190,9 +192,10 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
         vllm_config = engine_args.create_engine_config()
         executor_class = Executor.get_class(vllm_config)
 
-        engine_core = EngineCore(vllm_config=vllm_config,
-                                 executor_class=executor_class,
-                                 log_stats=True)
+        with set_default_torch_num_threads(1):
+            engine_core = EngineCore(vllm_config=vllm_config,
+                                     executor_class=executor_class,
+                                     log_stats=True)
         """Test basic request lifecycle."""
         # First request.
         request: EngineCoreRequest = make_request()
@@ -286,9 +289,10 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
             enforce_eager=True,
         )
         vllm_config = engine_args.create_engine_config()
-        engine_core = EngineCore(vllm_config=vllm_config,
-                                 log_stats=False,
-                                 executor_class=DummyExecutor)
+        with set_default_torch_num_threads(1):
+            engine_core = EngineCore(vllm_config=vllm_config,
+                                     log_stats=False,
+                                     executor_class=DummyExecutor)
         assert engine_core.batch_queue is not None
 
         # Add two requests in a row. Each request have 12 prompt tokens.
diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py
index c2dc3b4731b5..d4db16fe86fa 100644
--- a/tests/v1/engine/test_engine_core_client.py
+++ b/tests/v1/engine/test_engine_core_client.py
@@ -19,6 +19,7 @@ from vllm.distributed.kv_events import (BlockStored, KVEventBatch,
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import set_default_torch_num_threads
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core import EngineCore
 from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
@@ -138,13 +139,15 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
         vllm_config = engine_args.create_engine_config(
             UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
-        client = EngineCoreClient.make_client(
-            multiprocess_mode=multiprocessing_mode,
-            asyncio_mode=False,
-            vllm_config=vllm_config,
-            executor_class=executor_class,
-            log_stats=False,
-        )
+
+        with set_default_torch_num_threads(1):
+            client = EngineCoreClient.make_client(
+                multiprocess_mode=multiprocessing_mode,
+                asyncio_mode=False,
+                vllm_config=vllm_config,
+                executor_class=executor_class,
+                log_stats=False,
+            )
 
         MAX_TOKENS = 20
         params = SamplingParams(max_tokens=MAX_TOKENS)
@@ -223,13 +226,15 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
         vllm_config = engine_args.create_engine_config(
             usage_context=UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
-        client = EngineCoreClient.make_client(
-            multiprocess_mode=True,
-            asyncio_mode=True,
-            vllm_config=vllm_config,
-            executor_class=executor_class,
-            log_stats=True,
-        )
+
+        with set_default_torch_num_threads(1):
+            client = EngineCoreClient.make_client(
+                multiprocess_mode=True,
+                asyncio_mode=True,
+                vllm_config=vllm_config,
+                executor_class=executor_class,
+                log_stats=True,
+            )
 
         try:
             MAX_TOKENS = 20
@@ -312,13 +317,14 @@ def test_kv_cache_events(
             UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
 
-        client = EngineCoreClient.make_client(
-            multiprocess_mode=multiprocessing_mode,
-            asyncio_mode=False,
-            vllm_config=vllm_config,
-            executor_class=executor_class,
-            log_stats=False,
-        )
+        with set_default_torch_num_threads(1):
+            client = EngineCoreClient.make_client(
+                multiprocess_mode=multiprocessing_mode,
+                asyncio_mode=False,
+                vllm_config=vllm_config,
+                executor_class=executor_class,
+                log_stats=False,
+            )
         endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
         subscriber = MockSubscriber(endpoint,
                                     topic=publisher_config.topic,
@@ -394,13 +400,14 @@ async def test_kv_cache_events_dp(
             UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
 
-        client = EngineCoreClient.make_client(
-            multiprocess_mode=multiprocessing_mode,
-            asyncio_mode=True,
-            vllm_config=vllm_config,
-            executor_class=executor_class,
-            log_stats=False,
-        )
+        with set_default_torch_num_threads(1):
+            client = EngineCoreClient.make_client(
+                multiprocess_mode=multiprocessing_mode,
+                asyncio_mode=True,
+                vllm_config=vllm_config,
+                executor_class=executor_class,
+                log_stats=False,
+            )
         await asyncio.sleep(1)
 
         # Build endpoints for all DP ranks
diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py
index 3dad021e3166..66e78833f52a 100644
--- a/vllm/inputs/registry.py
+++ b/vllm/inputs/registry.py
@@ -168,10 +168,12 @@ class InputProcessingContext(InputContext):
         try:
             output = hf_processor(**data, **merged_kwargs, return_tensors="pt")
             # this emulates output.to(dtype=self.model_config.dtype)
-            cast_output = json_map_leaves(maybe_cast_dtype, output)
             if isinstance(output, BatchFeature):
+                cast_output = json_map_leaves(maybe_cast_dtype, output.data)
                 return BatchFeature(cast_output)
 
+            cast_output = json_map_leaves(maybe_cast_dtype, output)
+
             logger.warning_once(
                 f"{type(hf_processor).__name__} did not return `BatchFeature`. "
                 "Make sure to match the behaviour of `ProcessorMixin` when "
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 7770ec711ce7..73d241921bcf 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -965,9 +965,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         grid_thw_list = grid_thw.tolist()
 
         if image_input["type"] == "image_embeds":
-            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+            image_embeds = image_input["image_embeds"]
         else:
-            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+            pixel_values = image_input["pixel_values"]
             image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
 
         # Split concatenated embeddings for each image item.
@@ -985,10 +985,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         grid_thw_list = grid_thw.tolist()
 
         if video_input["type"] == "video_embeds":
-            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+            video_embeds = video_input["video_embeds"]
         else:
-            pixel_values_videos = video_input["pixel_values_videos"].type(
-                self.visual.dtype)
+            pixel_values_videos = video_input["pixel_values_videos"]
             video_embeds = self.visual(pixel_values_videos,
                                        grid_thw=grid_thw_list)
 
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index a4f8a361ec71..d8318fff868e 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -1208,9 +1208,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         assert grid_thw.ndim == 2
 
         if image_input["type"] == "image_embeds":
-            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+            image_embeds = image_input["image_embeds"]
         else:
-            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+            pixel_values = image_input["pixel_values"]
             image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
 
         # Split concatenated embeddings for each image item.
@@ -1226,10 +1226,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         assert grid_thw.ndim == 2
 
         if video_input["type"] == "video_embeds":
-            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+            video_embeds = video_input["video_embeds"]
         else:
-            pixel_values_videos = video_input["pixel_values_videos"].type(
-                self.visual.dtype)
+            pixel_values_videos = video_input["pixel_values_videos"]
             video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
 
         # Split concatenated embeddings for each video item.
diff --git a/vllm/utils.py b/vllm/utils.py
index 342241d0dd8a..dc408e1676f1 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -190,6 +190,16 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
     torch.int64: np.int64,
 }
 
+
+@contextlib.contextmanager
+def set_default_torch_num_threads(num_threads: int):
+    """Sets the default number of threads for PyTorch to the given value."""
+    old_num_threads = torch.get_num_threads()
+    torch.set_num_threads(num_threads)
+    yield
+    torch.set_num_threads(old_num_threads)
+
+
 P = ParamSpec('P')
 T = TypeVar("T")
 U = TypeVar("U")
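
Usage sketch (not part of the patch above): the new set_default_torch_num_threads helper is a plain context manager, so any engine construction that should start up with a single torch thread can be wrapped the same way the tests do. The engine arguments below are illustrative only; the model name is a placeholder and is not taken from this diff.

    from vllm.engine.arg_utils import EngineArgs
    from vllm.utils import set_default_torch_num_threads
    from vllm.v1.engine.core import EngineCore
    from vllm.v1.executor.abstract import Executor

    # Illustrative engine configuration (model name is a placeholder).
    engine_args = EngineArgs(model="facebook/opt-125m", enforce_eager=True)
    vllm_config = engine_args.create_engine_config()
    executor_class = Executor.get_class(vllm_config)

    # Pin torch to a single thread only while the engine is constructed; the
    # previous thread count is restored once the block completes.
    with set_default_torch_num_threads(1):
        engine_core = EngineCore(vllm_config=vllm_config,
                                 executor_class=executor_class,
                                 log_stats=False)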