diff --git a/src/transformers/models/bert/modeling_tf_bert.py b/src/transformers/models/bert/modeling_tf_bert.py
index 7fe89e43e..cc1218bbe 100644
--- a/src/transformers/models/bert/modeling_tf_bert.py
+++ b/src/transformers/models/bert/modeling_tf_bert.py
@@ -1182,10 +1182,10 @@ BERT_INPUTS_DOCSTRING = r"""
     BERT_START_DOCSTRING,
 )
 class TFBertModel(TFBertPreTrainedModel):
-    def __init__(self, config: BertConfig, *inputs, **kwargs):
+    def __init__(self, config: BertConfig, add_pooling_layer: bool = True, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
 
-        self.bert = TFBertMainLayer(config, name="bert")
+        self.bert = TFBertMainLayer(config, add_pooling_layer=add_pooling_layer, name="bert")
 
     @unpack_inputs
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))