diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index 18000e0462..90f88459fc 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -104,6 +104,10 @@ class ORTModule(torch.nn.Module): # else, they will be assigned to self._torch_module.original_module instead. self._is_initialized = True + # Delete ort._modules so that all references to ort._modules will be forwarded to the underlying torch_model + # through '__getattr__' + del self._modules + # IMPORTANT: DO NOT add code here # This declaration is for automatic document generation purposes only # The actual forward implementation is bound during ORTModule initialization diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index f668776d02..f39d29a178 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -4192,7 +4192,8 @@ def test_load_state_dict_for_wrapped_ortmodule(): x = torch.randn(N, D_in, device=device) _ = wrapper_module(x) - state_dict1 = wrapper_module.state_dict() + # Must copy the state_dict, or else the two state_dicts share the same memory + state_dict1 = copy.deepcopy(wrapper_module.state_dict()) list(next(iter(state_dict1.items())))[1] += 10 wrapper_module.load_state_dict(state_dict1) state_dict2 = wrapper_module.state_dict()