resolve review comments

This commit is contained in:
zhijxu 2020-11-12 16:07:05 +08:00 committed by zhijxu-MS
parent 89902c2519
commit 89e5b3a24f

View file

@ -646,6 +646,9 @@ class ORTTrainer(object):
self.options.graph_transformer.transformer_layer_recompute):
session_options.execution_order = ort.ExecutionOrder.PRIORITY_BASED
# old ort session may already exists and occupies GPU memory when creating new session, this may cause OOM error.
# for example, load_state_dict will be called before returing the function, and it calls _init_session again
del self._training_session
# TrainingSession
self._training_session = ort.TrainingSession(self._onnx_model.SerializeToString(),
ort_parameters,
@ -685,9 +688,6 @@ class ORTTrainer(object):
if self.options.utils.run_symbolic_shape_infer:
self._onnx_model = SymbolicShapeInference.infer_shapes(self._onnx_model, auto_merge=True, guess_output_rank=True)
# old ort session may already exists and occupies GPU memory when creating new session, this may cause OOM error.
# for example, load_state_dict will be called before returing the function, and it calls _init_session again
del self._training_session
# Create training session used by train_step
self._create_ort_training_session()