import os
import warnings


def _envvar_to_bool(text):
    """Translate the text of an environment variable into a bool.

    Matching ignores case and surrounding whitespace. The empty string maps
    to False, which is what ``bool("")`` already produced.

    Raises:
        ValueError: if ``text`` is not one of the recognised spellings.
    """
    normalized = text.strip().lower()
    if normalized in ("1", "true", "yes", "on"):
        return True
    if normalized in ("0", "false", "no", "off", ""):
        return False
    raise ValueError("invalid boolean value %r" % (text,))


def _defined_from_envvar(name, default_value, warn=True):
    """Return the value of environment variable ``name`` coerced to the type
    of ``default_value``, or ``default_value`` itself when no override applies.

    Typical use is letting an environment variable replace a module constant,
    e.g. ``ORTMODULE_ONNX_OPSET_VERSION`` overriding ``ONNX_OPSET_VERSION``.

    Args:
        name: name of the environment variable to read.
        default_value: value returned when the variable is unset or cannot be
            converted; its type also selects the conversion applied.
        warn: when True, emit a warning if the conversion fails.

    Returns:
        The converted environment value, or ``default_value``.
    """
    raw_value = os.getenv(name)
    if raw_value is None:
        return default_value
    try:
        # bool must be special-cased: bool("0") and bool("False") are both
        # True, so the generic constructor call below could never turn a
        # boolean constant off.
        if isinstance(default_value, bool):
            return _envvar_to_bool(raw_value)
        return type(default_value)(raw_value)
    except (TypeError, ValueError) as e:
        if warn:
            warnings.warn(
                "Unable to overwrite constant %r due to %r." % (name, e))
        return default_value
def test__defined_from_envvar():
    """Check that an env-var override is coerced to the default's type and
    that an unconvertible value warns and falls back to the default."""
    # Function-scope import so the name is in scope regardless of the
    # module-level imports of this test file.
    import warnings

    from onnxruntime.training import ortmodule

    env_name = 'DUMMY_ORTMODULE'
    try:
        os.environ[env_name] = '15'
        assert ortmodule._defined_from_envvar(env_name, 14) == 15

        os.environ[env_name] = '15j'
        with warnings.catch_warnings(record=True) as w:
            # Without an explicit "always" filter, a warning that was already
            # emitted once may be suppressed and never recorded.
            warnings.simplefilter("always")
            assert ortmodule._defined_from_envvar(env_name, 14) == 14
        assert len(w) == 1
        assert issubclass(w[-1].category, UserWarning)
        assert "Unable to overwrite constant" in str(w[-1].message)
    finally:
        # Always drop the dummy variable, even when an assertion fails.
        os.environ.pop(env_name, None)


def test_sigmoid_grad_opset13():
    """Check that ORTMODULE_ONNX_OPSET_VERSION=13 makes ORTModule produce
    opset-13 models and that results still match PyTorch.

    Runs on the 'cuda' device.
    """
    class NeuralNetSigmoid(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            super(NeuralNetSigmoid, self).__init__()

            self.fc1 = torch.nn.Linear(input_size, hidden_size)
            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, input1):
            out = self.fc1(input1)
            out = self.sigmoid(out)
            return out

    def run_step(model, x):
        prediction = model(x)
        loss = prediction.sum()
        loss.backward()
        return prediction, loss

    device = 'cuda'

    N, D_in, H, D_out = 120, 15360, 500, 15360
    pt_model = NeuralNetSigmoid(D_in, H, D_out).to(device)

    from onnxruntime.training import ortmodule
    env_name = "ORTMODULE_ONNX_OPSET_VERSION"
    saved_constant = ortmodule.ONNX_OPSET_VERSION
    saved_env = os.getenv(env_name, None)

    try:
        os.environ[env_name] = '13'
        # Setting the variable alone must not change the constant; it is only
        # expected to change once an ORTModule has been created (see the
        # final assertion of this block).
        assert ortmodule.ONNX_OPSET_VERSION == 12

        ort_model = ORTModule(copy.deepcopy(pt_model))

        for step in range(2):
            pt_x = torch.randn(N, D_in, device=device, requires_grad=True)
            ort_x = copy.deepcopy(pt_x)
            ort_prediction, ort_loss = run_step(ort_model, ort_x)
            pt_prediction, pt_loss = run_step(pt_model, pt_x)
            if step == 0:
                model_onx = ort_model._torch_module._execution_manager._training_manager._onnx_models
                for name in ['exported_model', 'optimized_model', 'optimized_pre_grad_model']:
                    onx = getattr(model_onx, name)
                    # The default ONNX domain is the empty string; the last
                    # matching entry wins, as in a plain loop.
                    versions = {op.domain: op.version for op in onx.opset_import}
                    assert versions.get('') == 13
            _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
            _test_helpers.assert_values_are_close(ort_x.grad, pt_x.grad)
            _test_helpers.assert_values_are_close(ort_loss, pt_loss)

        assert ortmodule.ONNX_OPSET_VERSION == 13
    finally:
        # Restore global state even on failure so that later tests do not
        # inherit the opset override.
        if saved_env is None:
            os.environ.pop(env_name, None)
        else:
            os.environ[env_name] = saved_env
        ortmodule.ONNX_OPSET_VERSION = saved_constant