mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-13 10:13:31 +08:00
[bugfix] fix cpu tests (#10585)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
d345f409b7
commit
d559979c54
@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||||
from vllm.multimodal import MultiModalKwargs
|
from vllm.multimodal import MultiModalKwargs
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
@ -64,7 +65,8 @@ class CPUEmbeddingModelRunner(
|
|||||||
intermediate_tensors,
|
intermediate_tensors,
|
||||||
}
|
}
|
||||||
|
|
||||||
hidden_states = model_executable(**execute_model_kwargs)
|
with set_forward_context(model_input.attn_metadata, self.vllm_config):
|
||||||
|
hidden_states = model_executable(**execute_model_kwargs)
|
||||||
|
|
||||||
# Only perform pooling in the driver worker.
|
# Only perform pooling in the driver worker.
|
||||||
if not self.is_driver_worker:
|
if not self.is_driver_worker:
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.attention import AttentionMetadata
|
from vllm.attention import AttentionMetadata
|
||||||
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.model_executor import SamplingMetadata
|
from vllm.model_executor import SamplingMetadata
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
from vllm.multimodal import MultiModalKwargs
|
from vllm.multimodal import MultiModalKwargs
|
||||||
@ -303,7 +304,8 @@ class CPUEncoderDecoderModelRunner(
|
|||||||
intermediate_tensors,
|
intermediate_tensors,
|
||||||
}
|
}
|
||||||
|
|
||||||
hidden_states = model_executable(**execute_model_kwargs)
|
with set_forward_context(model_input.attn_metadata, self.vllm_config):
|
||||||
|
hidden_states = model_executable(**execute_model_kwargs)
|
||||||
|
|
||||||
# Compute the logits.
|
# Compute the logits.
|
||||||
logits = self.model.compute_logits(hidden_states,
|
logits = self.model.compute_logits(hidden_states,
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from torch import nn
|
|||||||
|
|
||||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor import SamplingMetadata
|
from vllm.model_executor import SamplingMetadata
|
||||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||||
@ -487,14 +488,15 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
|
|||||||
multimodal_kwargs = MultiModalKwargs.as_kwargs(
|
multimodal_kwargs = MultiModalKwargs.as_kwargs(
|
||||||
model_input.multi_modal_kwargs, device=self.device)
|
model_input.multi_modal_kwargs, device=self.device)
|
||||||
|
|
||||||
hidden_states = model_executable(
|
with set_forward_context(model_input.attn_metadata, self.vllm_config):
|
||||||
input_ids=model_input.input_tokens,
|
hidden_states = model_executable(
|
||||||
positions=model_input.input_positions,
|
input_ids=model_input.input_tokens,
|
||||||
kv_caches=kv_caches,
|
positions=model_input.input_positions,
|
||||||
attn_metadata=model_input.attn_metadata,
|
kv_caches=kv_caches,
|
||||||
intermediate_tensors=intermediate_tensors,
|
attn_metadata=model_input.attn_metadata,
|
||||||
**multimodal_kwargs,
|
intermediate_tensors=intermediate_tensors,
|
||||||
)
|
**multimodal_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# Compute the logits.
|
# Compute the logits.
|
||||||
logits = self.model.compute_logits(hidden_states,
|
logits = self.model.compute_logits(hidden_states,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user