diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py
index cd93f0ef1e31..9c1c05320cf3 100644
--- a/vllm/model_executor/models/gpt_oss.py
+++ b/vllm/model_executor/models/gpt_oss.py
@@ -174,12 +174,15 @@ class TransformerBlock(torch.nn.Module):
     def __init__(
         self,
         config: GptOssConfig,
+        cache_config: CacheConfig,
         quant_config: QuantizationConfig,
         prefix: str = "",
     ):
         super().__init__()
         self.layer_idx = extract_layer_index(prefix)
-        self.attn = OAIAttention(config, prefix=f"{prefix}.attn")
+        self.attn = OAIAttention(config,
+                                 prefix=f"{prefix}.attn",
+                                 cache_config=cache_config)
         self.mlp = MLPBlock(config,
                             self.layer_idx,
                             quant_config=quant_config,
@@ -203,6 +206,7 @@ class GptOssModel(nn.Module):
     ):
         super().__init__()
         self.config = vllm_config.model_config.hf_config
+        self.cache_config = vllm_config.cache_config
         self.quant_config = vllm_config.quant_config
         self.parallel_config = vllm_config.parallel_config
         self.config.hidden_size = self.config.hidden_size
@@ -213,6 +217,7 @@ class GptOssModel(nn.Module):
         self.layers = torch.nn.ModuleList([
             TransformerBlock(
                 self.config,
+                cache_config=self.cache_config,
                 quant_config=self.quant_config,
                 prefix=maybe_prefix(prefix, f"block.{layer_idx}"),
             ) for layer_idx in range(self.config.num_hidden_layers)
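
For context on the change above: the diff threads cache_config from VllmConfig down through GptOssModel and TransformerBlock into OAIAttention. Below is a minimal sketch, not the actual OAIAttention implementation, assuming it follows the usual vLLM pattern of forwarding cache_config to vllm.attention.Attention so cache-related settings (e.g. the KV-cache dtype) take effect for this model.

# Hypothetical sketch only -- the real OAIAttention in gpt_oss.py also sets up
# QKV/output projections, RoPE, sinks, and sliding-window handling, which are
# omitted here. Class and attribute names beyond the diff are illustrative.
from typing import Optional

import torch

from vllm.attention import Attention
from vllm.config import CacheConfig


class OAIAttentionSketch(torch.nn.Module):
    """Stand-in illustrating how cache_config is plumbed into the attention layer."""

    def __init__(self,
                 config,
                 prefix: str = "",
                 cache_config: Optional[CacheConfig] = None):
        super().__init__()
        head_dim = getattr(config, "head_dim", 64)
        self.attn = Attention(
            config.num_attention_heads,
            head_dim,
            scale=head_dim**-0.5,
            num_kv_heads=config.num_key_value_heads,
            # Previously not passed; now sourced from VllmConfig.cache_config
            # via GptOssModel -> TransformerBlock -> OAIAttention.
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )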