From f05c285a58a98181773ce32cf3dcff64b54cb141 Mon Sep 17 00:00:00 2001 From: pengwa Date: Fri, 15 Oct 2021 12:23:13 +0800 Subject: [PATCH] Exception when duplicated autograd.Function name detected (#9351) * Exception when duplicated autograd.Function name detected * reorder a bit for a bittle bit better perf * fix a bug in previous PR :( * correct the error message a bit --- .../python/training/optim/_modifier.py | 6 ++-- .../_custom_autograd_function_exporter.py | 30 +++++++++++++++---- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/orttraining/orttraining/python/training/optim/_modifier.py b/orttraining/orttraining/python/training/optim/_modifier.py index 5c0ce9f7b6..491d13a49f 100644 --- a/orttraining/orttraining/python/training/optim/_modifier.py +++ b/orttraining/orttraining/python/training/optim/_modifier.py @@ -24,14 +24,14 @@ class FP16OptimizerModifier(object): def check_requirements(self, required_funcs, require_apex=False, require_torch_non_finite_check=False): try: - if require_apex: + if require_apex is True: import amp_C - if require_torch_non_finite_check: + if require_torch_non_finite_check is True: _ = torch._amp_foreach_non_finite_check_and_unscale_ except Exception as _: return False - if not required_funcs: + if required_funcs: for func_name in required_funcs: func = getattr(self._optimizer, func_name, None) if not func or not callable(func): diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 946d47d560..e1120d08de 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -185,6 +185,17 @@ def _post_process_after_export(exported_model, enable_custom_autograd_function, def _post_process_enabling_autograd_fallback(exported_model): + registered_name_mappings = {} + for kclass in torch.autograd.Function.__subclasses__(): + # Collect mapping of class names to full qualified class names. + if kclass.__name__ not in registered_name_mappings: + registered_name_mappings[kclass.__name__] = [] + full_qualified_name = kclass.__module__ + '.' + kclass.__qualname__ + registered_name_mappings[kclass.__name__].append(full_qualified_name) + + # Register function with class names. + register_torch_autograd_function(kclass.__name__, kclass) + index = 0 for node in exported_model.graph.node: if node.domain == 'com.microsoft' and node.op_type in ["PythonOp"]: @@ -192,13 +203,22 @@ def _post_process_enabling_autograd_fallback(exported_model): del node.output[:] node.output.append(output_names[0] + '_ctx') node.output.extend(output_names) + for attr in node.attribute: + if attr.name == 'name': + kclass_name = attr.s.decode('utf-8') if isinstance(attr.s, bytes) else attr.s + # If the duplicated function is used in ONNX graph, we will fail in case of a wrong function call. + # Todo: remove this trick once exporter can support fully qualified name for PythonOp. + if kclass_name in registered_name_mappings and len(registered_name_mappings[kclass_name]) > 1: + error_msg = 'More than one torch.autograd.Function named {}, but probabbly in different namespace. ' \ + 'The conflicting autograd.Functions are: {}. Currently torch exporter cannot ' \ + 'differentiate them with full qualified name, so there is a risk exported PythonOp calls a ' \ + 'wrong autograd.Function.'.format(kclass_name, ','.join(registered_name_mappings[kclass_name])) + raise wrap_exception(ORTModuleONNXModelException, RuntimeError(error_msg)) + + break + if not node.name: node.name = node.op_type + "_id_" + str(index) index += 1 - for kclass in torch.autograd.Function.__subclasses__(): - # Sometimes, we find the same functions multiple times, so we skip - # registrations when their keys already exist. - register_torch_autograd_function(kclass.__name__, kclass) - return exported_model