mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Exception when duplicated autograd.Function name detected (#9351)
* Exception when duplicated autograd.Function name detected * reorder a bit for a bittle bit better perf * fix a bug in previous PR :( * correct the error message a bit
This commit is contained in:
parent
74eaaad768
commit
f05c285a58
2 changed files with 28 additions and 8 deletions
|
|
@ -24,14 +24,14 @@ class FP16OptimizerModifier(object):
|
|||
|
||||
def check_requirements(self, required_funcs, require_apex=False, require_torch_non_finite_check=False):
|
||||
try:
|
||||
if require_apex:
|
||||
if require_apex is True:
|
||||
import amp_C
|
||||
if require_torch_non_finite_check:
|
||||
if require_torch_non_finite_check is True:
|
||||
_ = torch._amp_foreach_non_finite_check_and_unscale_
|
||||
except Exception as _:
|
||||
return False
|
||||
|
||||
if not required_funcs:
|
||||
if required_funcs:
|
||||
for func_name in required_funcs:
|
||||
func = getattr(self._optimizer, func_name, None)
|
||||
if not func or not callable(func):
|
||||
|
|
|
|||
|
|
@ -185,6 +185,17 @@ def _post_process_after_export(exported_model, enable_custom_autograd_function,
|
|||
|
||||
|
||||
def _post_process_enabling_autograd_fallback(exported_model):
|
||||
registered_name_mappings = {}
|
||||
for kclass in torch.autograd.Function.__subclasses__():
|
||||
# Collect mapping of class names to full qualified class names.
|
||||
if kclass.__name__ not in registered_name_mappings:
|
||||
registered_name_mappings[kclass.__name__] = []
|
||||
full_qualified_name = kclass.__module__ + '.' + kclass.__qualname__
|
||||
registered_name_mappings[kclass.__name__].append(full_qualified_name)
|
||||
|
||||
# Register function with class names.
|
||||
register_torch_autograd_function(kclass.__name__, kclass)
|
||||
|
||||
index = 0
|
||||
for node in exported_model.graph.node:
|
||||
if node.domain == 'com.microsoft' and node.op_type in ["PythonOp"]:
|
||||
|
|
@ -192,13 +203,22 @@ def _post_process_enabling_autograd_fallback(exported_model):
|
|||
del node.output[:]
|
||||
node.output.append(output_names[0] + '_ctx')
|
||||
node.output.extend(output_names)
|
||||
for attr in node.attribute:
|
||||
if attr.name == 'name':
|
||||
kclass_name = attr.s.decode('utf-8') if isinstance(attr.s, bytes) else attr.s
|
||||
# If the duplicated function is used in ONNX graph, we will fail in case of a wrong function call.
|
||||
# Todo: remove this trick once exporter can support fully qualified name for PythonOp.
|
||||
if kclass_name in registered_name_mappings and len(registered_name_mappings[kclass_name]) > 1:
|
||||
error_msg = 'More than one torch.autograd.Function named {}, but probabbly in different namespace. ' \
|
||||
'The conflicting autograd.Functions are: {}. Currently torch exporter cannot ' \
|
||||
'differentiate them with full qualified name, so there is a risk exported PythonOp calls a ' \
|
||||
'wrong autograd.Function.'.format(kclass_name, ','.join(registered_name_mappings[kclass_name]))
|
||||
raise wrap_exception(ORTModuleONNXModelException, RuntimeError(error_msg))
|
||||
|
||||
break
|
||||
|
||||
if not node.name:
|
||||
node.name = node.op_type + "_id_" + str(index)
|
||||
index += 1
|
||||
|
||||
for kclass in torch.autograd.Function.__subclasses__():
|
||||
# Sometimes, we find the same functions multiple times, so we skip
|
||||
# registrations when their keys already exist.
|
||||
register_torch_autograd_function(kclass.__name__, kclass)
|
||||
|
||||
return exported_model
|
||||
|
|
|
|||
Loading…
Reference in a new issue