diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 507e2cd3223fd..f493cc13ece2a 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -388,9 +388,9 @@ class Processor: eos_token_id = self.input_preprocessor.get_eos_token_id() - self._validate_model_inputs(processed_inputs) - encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) + self._validate_model_inputs(encoder_inputs, decoder_inputs) + # Mypy does not always properly infer the types of some elements of # discriminated unions of TypedDicts, because of how it handles # inheritance of TypedDict. If we explicitly extract the items we want @@ -458,9 +458,8 @@ class Processor: trace_headers=trace_headers, ) - def _validate_model_inputs(self, inputs: ProcessorInputs): - encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs) - + def _validate_model_inputs(self, encoder_inputs: Optional[SingletonInputs], + decoder_inputs: SingletonInputs): if encoder_inputs is not None: self._validate_model_input(encoder_inputs, prompt_type="encoder")