mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
resolve review comments
This commit is contained in:
parent
89902c2519
commit
89e5b3a24f
1 changed files with 3 additions and 3 deletions
|
|
@ -646,6 +646,9 @@ class ORTTrainer(object):
|
|||
self.options.graph_transformer.transformer_layer_recompute):
|
||||
session_options.execution_order = ort.ExecutionOrder.PRIORITY_BASED
|
||||
|
||||
# old ort session may already exists and occupies GPU memory when creating new session, this may cause OOM error.
|
||||
# for example, load_state_dict will be called before returing the function, and it calls _init_session again
|
||||
del self._training_session
|
||||
# TrainingSession
|
||||
self._training_session = ort.TrainingSession(self._onnx_model.SerializeToString(),
|
||||
ort_parameters,
|
||||
|
|
@ -685,9 +688,6 @@ class ORTTrainer(object):
|
|||
if self.options.utils.run_symbolic_shape_infer:
|
||||
self._onnx_model = SymbolicShapeInference.infer_shapes(self._onnx_model, auto_merge=True, guess_output_rank=True)
|
||||
|
||||
# old ort session may already exists and occupies GPU memory when creating new session, this may cause OOM error.
|
||||
# for example, load_state_dict will be called before returing the function, and it calls _init_session again
|
||||
del self._training_session
|
||||
# Create training session used by train_step
|
||||
self._create_ort_training_session()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue