mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
Del ort_model._modules to foward its accessing to torch_model._modules (#14563)
Missing '_modules' attribute in ORTModule will cause load_state_dict for wrapped_ortmodule fail. reference:https://github.com/microsoft/onnxruntime/pull/7847
This commit is contained in:
parent
8d87fdcfa1
commit
c49f250a14
2 changed files with 6 additions and 1 deletions
|
|
@ -104,6 +104,10 @@ class ORTModule(torch.nn.Module):
|
|||
# else, they will be assigned to self._torch_module.original_module instead.
|
||||
self._is_initialized = True
|
||||
|
||||
# del the ort._modules so that all reference to ort._modules will be forward to the underlying torch_model
|
||||
# through '__getattr__'
|
||||
del self._modules
|
||||
|
||||
# IMPORTANT: DO NOT add code here
|
||||
# This declaration is for automatic document generation purposes only
|
||||
# The actual forward implementation is bound during ORTModule initialization
|
||||
|
|
|
|||
|
|
@ -4192,7 +4192,8 @@ def test_load_state_dict_for_wrapped_ortmodule():
|
|||
x = torch.randn(N, D_in, device=device)
|
||||
_ = wrapper_module(x)
|
||||
|
||||
state_dict1 = wrapper_module.state_dict()
|
||||
# Must copy the state_dict or else they are sharing the same memory
|
||||
state_dict1 = copy.deepcopy(wrapper_module.state_dict())
|
||||
list(next(iter(state_dict1.items())))[1] += 10
|
||||
wrapper_module.load_state_dict(state_dict1)
|
||||
state_dict2 = wrapper_module.state_dict()
|
||||
|
|
|
|||
Loading…
Reference in a new issue