mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Implemented add_pooling_layer arg to TFBertModel (#29603)
Implemented add_pooling_layer argument
This commit is contained in:
parent
50ec493363
commit
f1a565a39f
1 changed files with 2 additions and 2 deletions
|
|
@ -1182,10 +1182,10 @@ BERT_INPUTS_DOCSTRING = r"""
|
|||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class TFBertModel(TFBertPreTrainedModel):
|
||||
def __init__(self, config: BertConfig, *inputs, **kwargs):
|
||||
def __init__(self, config: BertConfig, add_pooling_layer: bool = True, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.bert = TFBertMainLayer(config, name="bert")
|
||||
self.bert = TFBertMainLayer(config, add_pooling_layer, name="bert")
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
|
|
|
|||
Loading…
Reference in a new issue