Del ort_model._modules to foward its accessing to torch_model._modules (#14563)

Missing '_modules' attribute in ORTModule will cause load_state_dict for
wrapped_ortmodule fail.

reference:https://github.com/microsoft/onnxruntime/pull/7847
This commit is contained in:
guyang3532 2023-03-03 10:12:37 +08:00 committed by GitHub
parent 8d87fdcfa1
commit c49f250a14
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 1 deletions

View file

@ -104,6 +104,10 @@ class ORTModule(torch.nn.Module):
# else, they will be assigned to self._torch_module.original_module instead.
self._is_initialized = True
# del the ort._modules so that all reference to ort._modules will be forward to the underlying torch_model
# through '__getattr__'
del self._modules
# IMPORTANT: DO NOT add code here
# This declaration is for automatic document generation purposes only
# The actual forward implementation is bound during ORTModule initialization

View file

@ -4192,7 +4192,8 @@ def test_load_state_dict_for_wrapped_ortmodule():
x = torch.randn(N, D_in, device=device)
_ = wrapper_module(x)
state_dict1 = wrapper_module.state_dict()
# Must copy the state_dict or else they are sharing the same memory
state_dict1 = copy.deepcopy(wrapper_module.state_dict())
list(next(iter(state_dict1.items())))[1] += 10
wrapper_module.load_state_dict(state_dict1)
state_dict2 = wrapper_module.state_dict()