From 3bdb10d5ca4f258ec444863bcd5e839eeac5c238 Mon Sep 17 00:00:00 2001 From: jingyanwangms <47403504+jingyanwangms@users.noreply.github.com> Date: Thu, 22 Feb 2024 10:56:25 -0800 Subject: [PATCH] Move import to when needed to avoid circular dependency error (#19579) ### Description Move import to when needed to avoid circular dependency error ### Motivation and Context Fixes dependency error described here: https://github.com/microsoft/DeepSpeed/issues/5140 --------- Co-authored-by: Thiago Crepaldi --- .../python/training/ortmodule/_graph_execution_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 779b6bfe50..fda6e345da 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -20,7 +20,6 @@ import onnxruntime from onnxruntime.capi import _pybind_state as C from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference from onnxruntime.training.utils import ORTModelInputOutputSchemaType, PTable, onnx_dtype_to_pytorch_dtype -from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils from ._fallback import ( @@ -143,6 +142,9 @@ class GraphExecutionManager(GraphExecutionInterface): self._zero_stage3_param_map = {} if self._runtime_options.enable_zero_stage3_support: + # Move import to here to avoid circular dependency error + from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 # type: ignore[import] + # Cannot toggle feature enabling/disabling after the first time enabled. configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True)