Use ORTMODULE_ONNX_OPSET_VERSION to modify the opset version in ORTModule (#9529)

* Use environment variable to change the ONNX opset in ORTModule
* overwrite ONNX_OPSET_VERSION
* store envvar in module constant
This commit is contained in:
Xavier Dupré 2021-11-08 17:03:16 +01:00 committed by GitHub
parent 1151c661eb
commit 7e207ba3be
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 91 additions and 1 deletions

View file

@ -5,6 +5,7 @@
import os
import sys
import warnings
import torch
from packaging import version
@ -16,6 +17,21 @@ from ._fallback import (_FallbackPolicy,
wrap_exception)
from .torch_cpp_extensions import is_installed as is_torch_cpp_extensions_installed
def _defined_from_envvar(name, default_value, warn=True):
    """Return the value of environment variable *name*, cast to the type of *default_value*.

    Args:
        name: Name of the environment variable to read.
        default_value: Value returned when the variable is unset or cannot be
            converted. Its type also selects the conversion applied to the
            variable's string content.
        warn: When True, emit a warning if the variable is set but cannot be
            converted to the type of *default_value*.

    Returns:
        The converted environment value, or *default_value* when the variable
        is undefined or the conversion fails.
    """
    raw_value = os.getenv(name)
    if raw_value is None:
        return default_value

    target_type = type(default_value)
    try:
        if target_type is bool:
            # bool() on any non-empty string is True (bool("0"), bool("false")),
            # so boolean constants need explicit parsing of the usual spellings.
            lowered = raw_value.strip().lower()
            if lowered in ('1', 'true', 'yes', 'on'):
                return True
            if lowered in ('0', 'false', 'no', 'off', ''):
                return False
            raise ValueError("invalid boolean value %r" % raw_value)
        return target_type(raw_value)
    except (TypeError, ValueError) as e:
        if warn:
            warnings.warn(
                "Unable to overwrite constant %r due to %r." % (name, e))
        return default_value
################################################################################
# All global constants go here, before ORTModule is imported ###################
# NOTE: To *change* values in runtime, import onnxruntime.training.ortmodule and

View file

@ -93,6 +93,11 @@ class GraphExecutionManager(GraphExecutionInterface):
self._graph_initializer_names_to_train = None
self._graph_initializers = None
# Update constant ONNX_OPSET_VERSION with env var ORTMODULE_ONNX_OPSET_VERSION
# if defined.
ortmodule.ONNX_OPSET_VERSION = ortmodule._defined_from_envvar(
'ORTMODULE_ONNX_OPSET_VERSION', ortmodule.ONNX_OPSET_VERSION, warn=True)
# TrainingAgent or InferenceAgent
self._execution_agent = None

View file

@ -20,7 +20,6 @@ import tempfile
import os
import pickle
from distutils.version import LooseVersion
from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient
from onnxruntime.training.ortmodule import (ORTModule,
_utils,
@ -4268,6 +4267,76 @@ def test_tanh_grad():
_test_helpers.assert_values_are_close(ort_x.grad, pt_x.grad)
_test_helpers.assert_values_are_close(ort_loss, pt_loss)
def test__defined_from_envvar():
    """Check that ``_defined_from_envvar`` casts valid values and warns on invalid ones."""
    from onnxruntime.training import ortmodule

    try:
        # A well-formed integer string replaces the default.
        os.environ['DUMMY_ORTMODULE'] = '15'
        assert ortmodule._defined_from_envvar('DUMMY_ORTMODULE', 14) == 15

        # A malformed value keeps the default and emits exactly one warning.
        os.environ['DUMMY_ORTMODULE'] = '15j'
        with warnings.catch_warnings(record=True) as w:
            # Record every warning regardless of the filters installed by the
            # caller, so the count below does not depend on ambient settings.
            warnings.simplefilter("always")
            assert ortmodule._defined_from_envvar('DUMMY_ORTMODULE', 14) == 14
            assert len(w) == 1
            assert issubclass(w[-1].category, UserWarning)
            assert "Unable to overwrite constant" in str(w[-1].message)
    finally:
        # Always remove the dummy variable, even when an assertion fails,
        # so it cannot leak into other tests.
        os.environ.pop('DUMMY_ORTMODULE', None)
def test_sigmoid_grad_opset13():
    """Check that ORTMODULE_ONNX_OPSET_VERSION=13 is applied to the exported models.

    Runs two training steps of a Linear + Sigmoid network with both PyTorch and
    ORTModule, verifies that the ONNX models produced at the first step declare
    opset 13 for the default domain, and compares predictions, input gradients
    and losses between the two implementations.
    """
    class NeuralNetSigmoid(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            super(NeuralNetSigmoid, self).__init__()
            self.fc1 = torch.nn.Linear(input_size, hidden_size)
            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, input1):
            out = self.fc1(input1)
            out = self.sigmoid(out)
            return out

    def run_step(model, x):
        prediction = model(x)
        loss = prediction.sum()
        loss.backward()
        return prediction, loss

    device = 'cuda'

    N, D_in, H, D_out = 120, 15360, 500, 15360
    pt_model = NeuralNetSigmoid(D_in, H, D_out).to(device)

    from onnxruntime.training import ortmodule
    # Remember the global state this test modifies so that it can be restored.
    old_opst_cst = ortmodule.ONNX_OPSET_VERSION
    old_opset = os.getenv("ORTMODULE_ONNX_OPSET_VERSION", None)

    try:
        os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = '13'
        # Setting the variable alone must not change the module constant.
        assert ortmodule.ONNX_OPSET_VERSION == 12

        ort_model = ORTModule(copy.deepcopy(pt_model))

        for step in range(2):
            pt_x = torch.randn(N, D_in, device=device, requires_grad=True)
            ort_x = copy.deepcopy(pt_x)
            ort_prediction, ort_loss = run_step(ort_model, ort_x)
            pt_prediction, pt_loss = run_step(pt_model, pt_x)
            if step == 0:
                # Inspect the ONNX models built during the first step.
                model_onx = ort_model._torch_module._execution_manager._training_manager._onnx_models
                for name in ['exported_model', 'optimized_model', 'optimized_pre_grad_model']:
                    onx = getattr(model_onx, name)
                    opv = None
                    for op in onx.opset_import:
                        # The empty domain is the default ONNX operator set.
                        if op.domain == '':
                            opv = op.version
                    assert opv == 13
            _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
            _test_helpers.assert_values_are_close(ort_x.grad, pt_x.grad)
            _test_helpers.assert_values_are_close(ort_loss, pt_loss)

        # Creating the ORTModule must have overwritten the module constant.
        assert ortmodule.ONNX_OPSET_VERSION == 13
    finally:
        # Restore the environment and the module constant even when an
        # assertion fails, so later tests do not silently run with opset 13.
        if old_opset is None:
            os.environ.pop("ORTMODULE_ONNX_OPSET_VERSION", None)
        else:
            os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = old_opset
        ortmodule.ONNX_OPSET_VERSION = old_opst_cst
@pytest.mark.parametrize("opset_version", [12, 13])
def test_opset_version_change(opset_version):
device = 'cuda'