mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Don't duplicate logs in TensorBoard and handle --use_env (#11141)
This commit is contained in:
parent
9c9b8e707b
commit
dfed4ec263
2 changed files with 11 additions and 3 deletions
|
|
@@ -604,9 +604,11 @@ class TensorBoardCallback(TrainerCallback):
|
|||
self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})
|
||||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
if self.tb_writer is None:
|
||||
self._init_summary_writer(args)
|
||||
if not state.is_world_process_zero:
|
||||
return
|
||||
|
||||
if self.tb_writer is None:
|
||||
self._init_summary_writer(args)
|
||||
|
||||
if self.tb_writer is not None:
|
||||
logs = rewrite_logs(logs)
|
||||
|
|
|
|||
|
|
@@ -531,6 +531,12 @@ class TrainingArguments:
|
|||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
|
||||
# This needs to happen before any call to self.device or self.n_gpu.
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != self.local_rank:
|
||||
self.local_rank = env_local_rank
|
||||
|
||||
# expand paths, if not os.makedirs("~/bar") will make directory
|
||||
# in the current directory instead of the actual home
|
||||
# see https://github.com/huggingface/transformers/issues/10628
|
||||
|
|
|
|||
Loading…
Reference in a new issue