Exception when duplicated autograd.Function name detected (#9351)

* Exception when duplicated autograd.Function name detected

* reorder a bit for a little bit better perf

* fix a bug in previous PR :(

* correct the error message a bit
This commit is contained in:
pengwa 2021-10-15 12:23:13 +08:00 committed by GitHub
parent 74eaaad768
commit f05c285a58
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 8 deletions

View file

@ -24,14 +24,14 @@ class FP16OptimizerModifier(object):
def check_requirements(self, required_funcs, require_apex=False, require_torch_non_finite_check=False):
try:
if require_apex:
if require_apex is True:
import amp_C
if require_torch_non_finite_check:
if require_torch_non_finite_check is True:
_ = torch._amp_foreach_non_finite_check_and_unscale_
except Exception as _:
return False
if not required_funcs:
if required_funcs:
for func_name in required_funcs:
func = getattr(self._optimizer, func_name, None)
if not func or not callable(func):

View file

@ -185,6 +185,17 @@ def _post_process_after_export(exported_model, enable_custom_autograd_function,
def _post_process_enabling_autograd_fallback(exported_model):
registered_name_mappings = {}
for kclass in torch.autograd.Function.__subclasses__():
# Collect mapping of class names to full qualified class names.
if kclass.__name__ not in registered_name_mappings:
registered_name_mappings[kclass.__name__] = []
full_qualified_name = kclass.__module__ + '.' + kclass.__qualname__
registered_name_mappings[kclass.__name__].append(full_qualified_name)
# Register function with class names.
register_torch_autograd_function(kclass.__name__, kclass)
index = 0
for node in exported_model.graph.node:
if node.domain == 'com.microsoft' and node.op_type in ["PythonOp"]:
@ -192,13 +203,22 @@ def _post_process_enabling_autograd_fallback(exported_model):
del node.output[:]
node.output.append(output_names[0] + '_ctx')
node.output.extend(output_names)
for attr in node.attribute:
if attr.name == 'name':
kclass_name = attr.s.decode('utf-8') if isinstance(attr.s, bytes) else attr.s
# If the duplicated function is used in ONNX graph, we will fail in case of a wrong function call.
# Todo: remove this trick once exporter can support fully qualified name for PythonOp.
if kclass_name in registered_name_mappings and len(registered_name_mappings[kclass_name]) > 1:
error_msg = 'More than one torch.autograd.Function named {}, but probabbly in different namespace. ' \
'The conflicting autograd.Functions are: {}. Currently torch exporter cannot ' \
'differentiate them with full qualified name, so there is a risk exported PythonOp calls a ' \
'wrong autograd.Function.'.format(kclass_name, ','.join(registered_name_mappings[kclass_name]))
raise wrap_exception(ORTModuleONNXModelException, RuntimeError(error_msg))
break
if not node.name:
node.name = node.op_type + "_id_" + str(index)
index += 1
for kclass in torch.autograd.Function.__subclasses__():
# Sometimes, we find the same functions multiple times, so we skip
# registrations when their keys already exist.
register_torch_autograd_function(kclass.__name__, kclass)
return exported_model