[Bugfix] Fix auto dtype casting for BatchFeature (#19316)
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>

parent 6fa718a460
commit 2db9044ab6

@@ -15,6 +15,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.inputs import PromptType
 from vllm.platforms import current_platform
 from vllm.sampling_params import RequestOutputKind
+from vllm.utils import set_default_torch_num_threads
 from vllm.v1.engine.async_llm import AsyncLLM
 from vllm.v1.metrics.loggers import LoggingStatLogger
 
@@ -107,7 +108,8 @@ async def test_load(
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
-        engine = AsyncLLM.from_engine_args(engine_args)
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)
 
         NUM_REQUESTS = 100
@@ -154,7 +156,8 @@ async def test_abort(
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
-        engine = AsyncLLM.from_engine_args(engine_args)
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)
 
         NUM_REQUESTS = 100
@@ -226,7 +229,8 @@ async def test_finished_flag(
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
-        engine = AsyncLLM.from_engine_args(engine_args)
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)
 
         sampling_params = SamplingParams(
@@ -260,7 +264,8 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
-        engine = AsyncLLM.from_engine_args(engine_args)
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)
 
         NUM_REQUESTS = 100
@@ -322,10 +327,11 @@ async def test_customize_loggers(monkeypatch):
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
-        engine = AsyncLLM.from_engine_args(
-            TEXT_ENGINE_ARGS,
-            stat_loggers=[MockLoggingStatLogger],
-        )
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(
+                TEXT_ENGINE_ARGS,
+                stat_loggers=[MockLoggingStatLogger],
+            )
         after.callback(engine.shutdown)
 
         await engine.do_log_stats()
@@ -340,7 +346,8 @@ async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
-        engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
         after.callback(engine.shutdown)
 
         sampling_params = SamplingParams(max_tokens=100,
@@ -12,6 +12,7 @@ from transformers import AutoTokenizer
 from vllm import SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
+from vllm.utils import set_default_torch_num_threads
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core import EngineCore
 from vllm.v1.executor.abstract import Executor, UniProcExecutor
@@ -56,9 +57,10 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
         vllm_config = engine_args.create_engine_config()
         executor_class = Executor.get_class(vllm_config)
 
-        engine_core = EngineCore(vllm_config=vllm_config,
-                                 executor_class=executor_class,
-                                 log_stats=True)
+        with set_default_torch_num_threads(1):
+            engine_core = EngineCore(vllm_config=vllm_config,
+                                     executor_class=executor_class,
+                                     log_stats=True)
         """Test basic request lifecycle."""
 
         # First request.
@@ -190,9 +192,10 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
         vllm_config = engine_args.create_engine_config()
         executor_class = Executor.get_class(vllm_config)
 
-        engine_core = EngineCore(vllm_config=vllm_config,
-                                 executor_class=executor_class,
-                                 log_stats=True)
+        with set_default_torch_num_threads(1):
+            engine_core = EngineCore(vllm_config=vllm_config,
+                                     executor_class=executor_class,
+                                     log_stats=True)
         """Test basic request lifecycle."""
         # First request.
         request: EngineCoreRequest = make_request()
@@ -286,9 +289,10 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
         enforce_eager=True,
     )
     vllm_config = engine_args.create_engine_config()
-    engine_core = EngineCore(vllm_config=vllm_config,
-                             log_stats=False,
-                             executor_class=DummyExecutor)
+    with set_default_torch_num_threads(1):
+        engine_core = EngineCore(vllm_config=vllm_config,
+                                 log_stats=False,
+                                 executor_class=DummyExecutor)
     assert engine_core.batch_queue is not None
 
     # Add two requests in a row. Each request have 12 prompt tokens.
@@ -19,6 +19,7 @@ from vllm.distributed.kv_events import (BlockStored, KVEventBatch,
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import set_default_torch_num_threads
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core import EngineCore
 from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
@@ -138,13 +139,15 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
         vllm_config = engine_args.create_engine_config(
             UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
-        client = EngineCoreClient.make_client(
-            multiprocess_mode=multiprocessing_mode,
-            asyncio_mode=False,
-            vllm_config=vllm_config,
-            executor_class=executor_class,
-            log_stats=False,
-        )
+
+        with set_default_torch_num_threads(1):
+            client = EngineCoreClient.make_client(
+                multiprocess_mode=multiprocessing_mode,
+                asyncio_mode=False,
+                vllm_config=vllm_config,
+                executor_class=executor_class,
+                log_stats=False,
+            )
 
         MAX_TOKENS = 20
         params = SamplingParams(max_tokens=MAX_TOKENS)
@@ -223,13 +226,15 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
         vllm_config = engine_args.create_engine_config(
             usage_context=UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
-        client = EngineCoreClient.make_client(
-            multiprocess_mode=True,
-            asyncio_mode=True,
-            vllm_config=vllm_config,
-            executor_class=executor_class,
-            log_stats=True,
-        )
+
+        with set_default_torch_num_threads(1):
+            client = EngineCoreClient.make_client(
+                multiprocess_mode=True,
+                asyncio_mode=True,
+                vllm_config=vllm_config,
+                executor_class=executor_class,
+                log_stats=True,
+            )
 
         try:
             MAX_TOKENS = 20
@@ -312,13 +317,14 @@ def test_kv_cache_events(
             UsageContext.UNKNOWN_CONTEXT)
 
         executor_class = Executor.get_class(vllm_config)
-        client = EngineCoreClient.make_client(
-            multiprocess_mode=multiprocessing_mode,
-            asyncio_mode=False,
-            vllm_config=vllm_config,
-            executor_class=executor_class,
-            log_stats=False,
-        )
+        with set_default_torch_num_threads(1):
+            client = EngineCoreClient.make_client(
+                multiprocess_mode=multiprocessing_mode,
+                asyncio_mode=False,
+                vllm_config=vllm_config,
+                executor_class=executor_class,
+                log_stats=False,
+            )
         endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
         subscriber = MockSubscriber(endpoint,
                                     topic=publisher_config.topic,
@@ -394,13 +400,14 @@ async def test_kv_cache_events_dp(
             UsageContext.UNKNOWN_CONTEXT)
 
         executor_class = Executor.get_class(vllm_config)
-        client = EngineCoreClient.make_client(
-            multiprocess_mode=multiprocessing_mode,
-            asyncio_mode=True,
-            vllm_config=vllm_config,
-            executor_class=executor_class,
-            log_stats=False,
-        )
+        with set_default_torch_num_threads(1):
+            client = EngineCoreClient.make_client(
+                multiprocess_mode=multiprocessing_mode,
+                asyncio_mode=True,
+                vllm_config=vllm_config,
+                executor_class=executor_class,
+                log_stats=False,
+            )
         await asyncio.sleep(1)
 
         # Build endpoints for all DP ranks
@@ -168,10 +168,12 @@ class InputProcessingContext(InputContext):
         try:
             output = hf_processor(**data, **merged_kwargs, return_tensors="pt")
             # this emulates output.to(dtype=self.model_config.dtype)
-            cast_output = json_map_leaves(maybe_cast_dtype, output)
             if isinstance(output, BatchFeature):
+                cast_output = json_map_leaves(maybe_cast_dtype, output.data)
                 return BatchFeature(cast_output)
 
+            cast_output = json_map_leaves(maybe_cast_dtype, output)
+
             logger.warning_once(
                 f"{type(hf_processor).__name__} did not return `BatchFeature`. "
                 "Make sure to match the behaviour of `ProcessorMixin` when "
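
The hunk above maps the dtype cast over the `BatchFeature`'s underlying `.data` dict and rebuilds the `BatchFeature`, so only floating-point leaves change dtype while integer features such as `input_ids` are left alone. Below is a minimal, standalone sketch of that behaviour; the `_maybe_cast_dtype` helper and the bfloat16 target are illustrative stand-ins, not vLLM's actual `maybe_cast_dtype`/`json_map_leaves` internals.

import torch
from transformers import BatchFeature


def _maybe_cast_dtype(x, dtype=torch.bfloat16):
    # Cast only floating-point tensors; leave integer/bool tensors untouched.
    if isinstance(x, torch.Tensor) and x.is_floating_point():
        return x.to(dtype)
    return x


output = BatchFeature({
    "input_ids": torch.tensor([[1, 2, 3]]),       # stays int64
    "pixel_values": torch.rand(1, 3, 224, 224),   # float32 -> bfloat16
})

# Map over `.data` (a plain dict) and rebuild the BatchFeature, mirroring the
# `return BatchFeature(cast_output)` branch in the hunk above.
cast = BatchFeature({k: _maybe_cast_dtype(v) for k, v in output.data.items()})
assert cast["input_ids"].dtype == torch.int64
assert cast["pixel_values"].dtype == torch.bfloat16
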
@@ -965,9 +965,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         grid_thw_list = grid_thw.tolist()
 
         if image_input["type"] == "image_embeds":
-            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+            image_embeds = image_input["image_embeds"]
         else:
-            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+            pixel_values = image_input["pixel_values"]
             image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
 
         # Split concatenated embeddings for each image item.
@@ -985,10 +985,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         grid_thw_list = grid_thw.tolist()
 
         if video_input["type"] == "video_embeds":
-            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+            video_embeds = video_input["video_embeds"]
         else:
-            pixel_values_videos = video_input["pixel_values_videos"].type(
-                self.visual.dtype)
+            pixel_values_videos = video_input["pixel_values_videos"]
             video_embeds = self.visual(pixel_values_videos,
                                        grid_thw=grid_thw_list)
 
@@ -1208,9 +1208,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         assert grid_thw.ndim == 2
 
         if image_input["type"] == "image_embeds":
-            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+            image_embeds = image_input["image_embeds"]
         else:
-            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+            pixel_values = image_input["pixel_values"]
             image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
 
         # Split concatenated embeddings for each image item.
@@ -1226,10 +1226,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         assert grid_thw.ndim == 2
 
         if video_input["type"] == "video_embeds":
-            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+            video_embeds = video_input["video_embeds"]
         else:
-            pixel_values_videos = video_input["pixel_values_videos"].type(
-                self.visual.dtype)
+            pixel_values_videos = video_input["pixel_values_videos"]
             video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
 
         # Split concatenated embeddings for each video item.
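
The `.type(self.visual.dtype)` calls removed in the Qwen hunks above should no longer be needed once the processor output is cast up front by the `BatchFeature` fix: the pixel values reaching the model are expected to already match the vision tower's dtype. A small hedged sketch of that equivalence, where `visual_dtype` is an illustrative stand-in for `self.visual.dtype`:

import torch

visual_dtype = torch.bfloat16  # illustrative stand-in for self.visual.dtype

pixel_values = torch.rand(4, 1176)          # raw float32 processor output
explicit = pixel_values.type(visual_dtype)  # the removed in-model cast

# With the registry fix, floating-point features arrive already cast, so the
# tensor handed to self.visual matches the dtype without the extra .type(...).
pre_cast = pixel_values.to(visual_dtype)
assert pre_cast.dtype == explicit.dtype == visual_dtype
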
@@ -190,6 +190,16 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
     torch.int64: np.int64,
 }
 
+
+@contextlib.contextmanager
+def set_default_torch_num_threads(num_threads: int):
+    """Sets the default number of threads for PyTorch to the given value."""
+    old_num_threads = torch.get_num_threads()
+    torch.set_num_threads(num_threads)
+    yield
+    torch.set_num_threads(old_num_threads)
+
+
 P = ParamSpec('P')
 T = TypeVar("T")
 U = TypeVar("U")
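
A brief usage sketch of the new context manager, matching how the tests above wrap engine construction. Note that the restore runs after `yield` without a `try`/`finally`, so it relies on the body not raising.

import torch

from vllm.utils import set_default_torch_num_threads

with set_default_torch_num_threads(1):
    # Engine construction happens here with a single torch thread, e.g.
    # engine_core = EngineCore(...) or AsyncLLM.from_engine_args(...),
    # as in the tests above.
    assert torch.get_num_threads() == 1

# The previous thread count is restored on exit.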