[bugfix] fix cpu tests (#10585)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2024-11-22 17:34:03 -08:00 committed by GitHub
parent d345f409b7
commit d559979c54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 16 additions and 10 deletions

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch import torch
from vllm.forward_context import set_forward_context
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalKwargs from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
@@ -64,7 +65,8 @@ class CPUEmbeddingModelRunner(
intermediate_tensors, intermediate_tensors,
} }
hidden_states = model_executable(**execute_model_kwargs) with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_states = model_executable(**execute_model_kwargs)
# Only perform pooling in the driver worker. # Only perform pooling in the driver worker.
if not self.is_driver_worker: if not self.is_driver_worker:

View File

@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast
import torch import torch
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.forward_context import set_forward_context
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalKwargs from vllm.multimodal import MultiModalKwargs
@@ -303,7 +304,8 @@ class CPUEncoderDecoderModelRunner(
intermediate_tensors, intermediate_tensors,
} }
hidden_states = model_executable(**execute_model_kwargs) with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_states = model_executable(**execute_model_kwargs)
# Compute the logits. # Compute the logits.
logits = self.model.compute_logits(hidden_states, logits = self.model.compute_logits(hidden_states,

View File

@@ -10,6 +10,7 @@ from torch import nn
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@@ -487,14 +488,15 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
multimodal_kwargs = MultiModalKwargs.as_kwargs( multimodal_kwargs = MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs, device=self.device) model_input.multi_modal_kwargs, device=self.device)
hidden_states = model_executable( with set_forward_context(model_input.attn_metadata, self.vllm_config):
input_ids=model_input.input_tokens, hidden_states = model_executable(
positions=model_input.input_positions, input_ids=model_input.input_tokens,
kv_caches=kv_caches, positions=model_input.input_positions,
attn_metadata=model_input.attn_metadata, kv_caches=kv_caches,
intermediate_tensors=intermediate_tensors, attn_metadata=model_input.attn_metadata,
**multimodal_kwargs, intermediate_tensors=intermediate_tensors,
) **multimodal_kwargs,
)
# Compute the logits. # Compute the logits.
logits = self.model.compute_logits(hidden_states, logits = self.model.compute_logits(hidden_states,