mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
add model.zero_grad()
This commit is contained in:
parent
a4086c5de5
commit
cb76c1ddd3
2 changed files with 2 additions and 0 deletions
|
|
@ -531,6 +531,7 @@ def main():
|
|||
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
|
||||
total_tr_loss += loss.item()
|
||||
nb_tr_examples += input_ids.size(0)
|
||||
model.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
global_step += 1
|
||||
|
|
|
|||
|
|
@ -856,6 +856,7 @@ def main():
|
|||
|
||||
logger.info("HHHHH Forward")
|
||||
loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
|
||||
model.zero_grad()
|
||||
logger.info("HHHHH Backward")
|
||||
loss.backward()
|
||||
logger.info("HHHHH Loading data")
|
||||
|
|
|
|||
Loading…
Reference in a new issue