From 986dd5c5bfe97566ea3bc1db17982118ef09e920 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Tue, 17 May 2022 12:50:14 -0400 Subject: [PATCH] Fix style --- src/transformers/models/longformer/modeling_longformer.py | 5 ++++- .../models/longformer/modeling_tf_longformer.py | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index d974c0f4d..30db98dea 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -2140,7 +2140,10 @@ class LongformerForTokenClassification(LongformerPreTrainedModel): checkpoint="brad1141/Longformer-finetuned-norm", output_type=LongformerTokenClassifierOutput, config_class=_CONFIG_FOR_DOC, - expected_output="['Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence']", + expected_output=( + "['Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence'," + " 'Evidence', 'Evidence', 'Evidence', 'Evidence']" + ), expected_loss=0.63, ) def forward( diff --git a/src/transformers/models/longformer/modeling_tf_longformer.py b/src/transformers/models/longformer/modeling_tf_longformer.py index 81809ce18..0dfd9c666 100644 --- a/src/transformers/models/longformer/modeling_tf_longformer.py +++ b/src/transformers/models/longformer/modeling_tf_longformer.py @@ -2591,7 +2591,11 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla checkpoint="hf-internal-testing/tiny-random-longformer", output_type=TFLongformerTokenClassifierOutput, config_class=_CONFIG_FOR_DOC, - expected_output="['LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1']", + expected_output=( + "['LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1'," + " 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1'," + " 'LABEL_1', 'LABEL_1']" + ), expected_loss=0.59, ) def call(