mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
add tqdm, clean up logging
This commit is contained in:
parent
d4e3cf3520
commit
0b7a20c651
2 changed files with 2 additions and 4 deletions
|
|
@ -435,7 +435,6 @@ class BertForSequenceClassification(nn.Module):
|
|||
|
||||
def init_weights(m):
|
||||
if isinstance(m, (nn.Linear, nn.Embedding)):
|
||||
print("Initializing {}".format(m))
|
||||
# Slight difference here with the TF version which uses truncated_normal
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
m.weight.data.normal_(config.initializer_range)
|
||||
|
|
@ -481,7 +480,6 @@ class BertForQuestionAnswering(nn.Module):
|
|||
|
||||
def init_weights(m):
|
||||
if isinstance(m, (nn.Linear, nn.Embedding)):
|
||||
print("Initializing {}".format(m))
|
||||
# Slight difference here with the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
m.weight.data.normal_(config.initializer_range)
|
||||
|
|
|
|||
|
|
@ -912,9 +912,9 @@ def main():
|
|||
|
||||
model.eval()
|
||||
all_results = []
|
||||
logger.info("Start evaulating")
|
||||
logger.info("Start evaluating")
|
||||
#for input_ids, input_mask, segment_ids, label_ids, example_index in eval_dataloader:
|
||||
for input_ids, input_mask, segment_ids, example_index in eval_dataloader:
|
||||
for input_ids, input_mask, segment_ids, example_index in tqdm(eval_dataloader, descr="Evaluating"):
|
||||
if len(all_results) % 1000 == 0:
|
||||
logger.info("Processing example: %d" % (len(all_results)))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue