Implemented add_pooling_layer arg to TFBertModel (#29603)

Implemented add_pooling_layer argument
This commit is contained in:
tomigee 2024-03-12 09:01:55 -04:00 committed by GitHub
parent 50ec493363
commit f1a565a39f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1182,10 +1182,10 @@ BERT_INPUTS_DOCSTRING = r"""
BERT_START_DOCSTRING,
)
class TFBertModel(TFBertPreTrainedModel):
def __init__(self, config: BertConfig, *inputs, **kwargs):
def __init__(self, config: BertConfig, add_pooling_layer: bool = True, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.bert = TFBertMainLayer(config, name="bert")
self.bert = TFBertMainLayer(config, add_pooling_layer, name="bert")
@unpack_inputs
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))