From 89e5b3a24f29ffdca490147ad6aef882cb74e52f Mon Sep 17 00:00:00 2001 From: zhijxu Date: Thu, 12 Nov 2020 16:07:05 +0800 Subject: [PATCH] resolve review comments --- orttraining/orttraining/python/training/orttrainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py index 520bcd1a33..1cb243b788 100644 --- a/orttraining/orttraining/python/training/orttrainer.py +++ b/orttraining/orttraining/python/training/orttrainer.py @@ -646,6 +646,9 @@ class ORTTrainer(object): self.options.graph_transformer.transformer_layer_recompute): session_options.execution_order = ort.ExecutionOrder.PRIORITY_BASED + # old ort session may already exist and occupy GPU memory when creating new session, this may cause OOM error. + # for example, load_state_dict will be called before returning the function, and it calls _init_session again + del self._training_session # TrainingSession self._training_session = ort.TrainingSession(self._onnx_model.SerializeToString(), ort_parameters, @@ -685,9 +688,6 @@ class ORTTrainer(object): if self.options.utils.run_symbolic_shape_infer: self._onnx_model = SymbolicShapeInference.infer_shapes(self._onnx_model, auto_merge=True, guess_output_rank=True) - # old ort session may already exists and occupies GPU memory when creating new session, this may cause OOM error. - # for example, load_state_dict will be called before returing the function, and it calls _init_session again - del self._training_session # Create training session used by train_step self._create_ort_training_session()