mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-19 21:32:23 +00:00
Use ORTMODULE_ONNX_OPSET_VERSION to modify the opset version in OrtModule (#9529)
* Use environment variable to change the ONNX opset in ORTModule * overwrite ONNX_OPSET_VERSION * store envvar in module constant
This commit is contained in:
parent
1151c661eb
commit
7e207ba3be
3 changed files with 91 additions and 1 deletions
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
|
|
@ -16,6 +17,21 @@ from ._fallback import (_FallbackPolicy,
|
|||
wrap_exception)
|
||||
from .torch_cpp_extensions import is_installed as is_torch_cpp_extensions_installed
|
||||
|
||||
|
||||
def _defined_from_envvar(name, default_value, warn=True):
    """Return a constant's value, optionally overridden by an environment variable.

    The environment variable *name* is read and, when it is set, converted to
    the type of *default_value* (``type(default_value)``).

    Args:
        name: Name of the environment variable to look up.
        default_value: Value returned when the variable is unset or cannot be
            converted; its type also selects the conversion applied.
        warn: When True, emit a warning if the conversion fails.

    Returns:
        The converted environment value, or *default_value* when the variable
        is not defined or its content cannot be converted.
    """
    raw_value = os.getenv(name)
    if raw_value is None:
        return default_value

    target_type = type(default_value)
    try:
        if target_type is bool:
            # bool('0') and bool('false') are both True because any non-empty
            # string is truthy, so booleans need explicit parsing.
            flag = raw_value.strip().lower()
            if flag in ('1', 'true', 'yes', 'on'):
                return True
            if flag in ('', '0', 'false', 'no', 'off'):
                return False
            raise ValueError("invalid boolean value %r" % raw_value)
        return target_type(raw_value)
    except (TypeError, ValueError) as e:
        if warn:
            warnings.warn(
                "Unable to overwrite constant %r due to %r." % (name, e))
        return default_value
|
||||
|
||||
|
||||
################################################################################
|
||||
# All global constants go here, before ORTModule is imported ###################
|
||||
# NOTE: To *change* values in runtime, import onnxruntime.training.ortmodule and
|
||||
|
|
|
|||
|
|
@ -93,6 +93,11 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
self._graph_initializer_names_to_train = None
|
||||
self._graph_initializers = None
|
||||
|
||||
# Update constant ONNX_OPSET_VERSION with env var ORTMODULE_ONNX_OPSET_VERSION
|
||||
# if defined.
|
||||
ortmodule.ONNX_OPSET_VERSION = ortmodule._defined_from_envvar(
|
||||
'ORTMODULE_ONNX_OPSET_VERSION', ortmodule.ONNX_OPSET_VERSION, warn=True)
|
||||
|
||||
# TrainingAgent or InferenceAgent
|
||||
self._execution_agent = None
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ import tempfile
|
|||
import os
|
||||
import pickle
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient
|
||||
from onnxruntime.training.ortmodule import (ORTModule,
|
||||
_utils,
|
||||
|
|
@ -4268,6 +4267,76 @@ def test_tanh_grad():
|
|||
_test_helpers.assert_values_are_close(ort_x.grad, pt_x.grad)
|
||||
_test_helpers.assert_values_are_close(ort_loss, pt_loss)
|
||||
|
||||
|
||||
def test__defined_from_envvar():
    """Check that ``_defined_from_envvar`` converts valid values and warns on invalid ones."""
    from onnxruntime.training import ortmodule

    os.environ['DUMMY_ORTMODULE'] = '15'
    try:
        # A value convertible to the default's type (int) overrides the default.
        assert ortmodule._defined_from_envvar('DUMMY_ORTMODULE', 14) == 15

        # A value that cannot be converted falls back to the default and warns.
        os.environ['DUMMY_ORTMODULE'] = '15j'
        with warnings.catch_warnings(record=True) as w:
            # Record every warning regardless of filters installed elsewhere,
            # otherwise the count below depends on global warning state.
            warnings.simplefilter("always")
            assert ortmodule._defined_from_envvar('DUMMY_ORTMODULE', 14) == 14
        assert len(w) == 1
        assert issubclass(w[-1].category, UserWarning)
        assert "Unable to overwrite constant" in str(w[-1].message)
    finally:
        # Always remove the dummy variable, even when an assertion fails.
        os.environ.pop('DUMMY_ORTMODULE', None)
|
||||
|
||||
|
||||
def test_sigmoid_grad_opset13():
    """Train a Linear+Sigmoid model through ORTModule with the opset forced to 13.

    The opset is selected through the ORTMODULE_ONNX_OPSET_VERSION environment
    variable. The test checks that the exported/optimized ONNX models report
    opset 13 for the default domain and that predictions, input gradients and
    losses match the PyTorch model.
    """
    class NeuralNetSigmoid(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            super(NeuralNetSigmoid, self).__init__()

            self.fc1 = torch.nn.Linear(input_size, hidden_size)
            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, input1):
            out = self.fc1(input1)
            out = self.sigmoid(out)
            return out

    def run_step(model, x):
        prediction = model(x)
        loss = prediction.sum()
        loss.backward()
        return prediction, loss

    device = 'cuda'

    N, D_in, H, D_out = 120, 15360, 500, 15360
    pt_model = NeuralNetSigmoid(D_in, H, D_out).to(device)

    from onnxruntime.training import ortmodule
    old_opst_cst = ortmodule.ONNX_OPSET_VERSION
    old_opset = os.getenv("ORTMODULE_ONNX_OPSET_VERSION", None)
    os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = '13'
    try:
        # Setting the variable alone must not change the module constant yet.
        assert ortmodule.ONNX_OPSET_VERSION == 12

        ort_model = ORTModule(copy.deepcopy(pt_model))

        for step in range(2):
            pt_x = torch.randn(N, D_in, device=device, requires_grad=True)
            ort_x = copy.deepcopy(pt_x)
            ort_prediction, ort_loss = run_step(ort_model, ort_x)
            pt_prediction, pt_loss = run_step(pt_model, pt_x)
            if step == 0:
                # Inspect the ONNX models once, after the first step has run.
                model_onx = ort_model._torch_module._execution_manager._training_manager._onnx_models
                for name in ['exported_model', 'optimized_model', 'optimized_pre_grad_model']:
                    onx = getattr(model_onx, name)
                    opv = None
                    for op in onx.opset_import:
                        # The empty domain is the default ONNX operator set.
                        if op.domain == '':
                            opv = op.version
                    assert opv == 13
            _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
            _test_helpers.assert_values_are_close(ort_x.grad, pt_x.grad)
            _test_helpers.assert_values_are_close(ort_loss, pt_loss)

        # The module constant has been updated from the environment variable.
        assert ortmodule.ONNX_OPSET_VERSION == 13
    finally:
        # Restore the environment and the module constant even when an
        # assertion fails, so opset 13 does not leak into other tests.
        if old_opset is None:
            os.environ.pop("ORTMODULE_ONNX_OPSET_VERSION", None)
        else:
            os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = old_opset
        ortmodule.ONNX_OPSET_VERSION = old_opst_cst
|
||||
|
||||
@pytest.mark.parametrize("opset_version", [12, 13])
|
||||
def test_opset_version_change(opset_version):
|
||||
device = 'cuda'
|
||||
|
|
|
|||
Loading…
Reference in a new issue