diff --git a/vllm/v1/worker/gpu/sampler.py b/vllm/v1/worker/gpu/sampler.py index d03309abca61a..4b4a250ebaa19 100644 --- a/vllm/v1/worker/gpu/sampler.py +++ b/vllm/v1/worker/gpu/sampler.py @@ -17,10 +17,10 @@ class Sampler(nn.Module): def __init__( self, - logprobs_mode: LogprobsMode = LogprobsMode.PROCESSED_LOGPROBS, + logprobs_mode: LogprobsMode = "processed_logprobs", ): super().__init__() - assert logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS + assert logprobs_mode == "processed_logprobs" self.logprobs_mode = logprobs_mode def forward(