Remove the use of eval in test code (#15097)

### Description

Remove the use of `eval` in test code so that we (1) no longer rely on `eval` and (2)
no longer create seemingly "unused" local vars that ruff would remove. This is a predecessor to #15085.
This commit is contained in:
Justin Chu 2023-03-20 09:43:56 -07:00 committed by GitHub
parent c964da7ea2
commit bdd7bd084c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 121 additions and 213 deletions

View file

@ -3,12 +3,12 @@
# pylint: disable=missing-docstring, too-many-public-methods, no-member
import operator
import unittest
import numpy as np
import onnxruntime_pybind11_state as torch_ort
import torch
from parameterized import parameterized, param
from parameterized import param, parameterized
class OrtOpTests(unittest.TestCase):
@ -570,11 +570,6 @@ class OrtOpTests(unittest.TestCase):
# For floor and erf, ORT produces a round-off error for NaN input, whereas the CPU keeps it as NaN.
# Thus, we use nan_to_num to ensure that actual numbers are passed in.
# As many of the following use eval and make it appear to pylint that there are many unused variables,
# we disable those warnings
# pylint: disable=eval-used, unused-argument, unused-variable, no-self-argument,
ops = [
["abs", torch.tensor([-1, -2, 3, -6, -7])],
["acos"],
@ -616,20 +611,20 @@ class OrtOpTests(unittest.TestCase):
# @parameterized.expand generates test methods for ops, and using name_func we rename each test to be test_{ops}
@parameterized.expand(ops, name_func=rename_func)
def test_op(self, test_name, tensor_test=torch.rand(6)):
# compile eval- creates a code object that evaluates the operator (for example torch.abs(tensor_test)) and returns its result.
cpu_result = eval(compile("torch." + test_name + "(tensor_test)", "<string>", "eval"))
ort_result = eval(compile("torch." + test_name + "(tensor_test.to(self.get_device()))", "<string>", "eval"))
cpu_result = getattr(torch, test_name)(tensor_test)
ort_result = getattr(torch, test_name)(tensor_test.to(self.get_device()))
assert torch.allclose(cpu_result, ort_result.cpu(), equal_nan=True)
@parameterized.expand(ops, name_func=rename_func)
def test_op_(self, test_name, tensor_test=torch.rand(6)):
def test_op_inplace(self, test_name, tensor_test=torch.rand(6)):
device = self.get_device()
cpu_tensor = tensor_test
ort_tensor = cpu_tensor.to(device)
eval(compile("torch." + test_name + "_(cpu_tensor)", "<string>", "eval"))
eval(compile("torch." + test_name + "_(ort_tensor)", "<string>", "eval"))
getattr(torch, test_name + "_")(cpu_tensor)
getattr(torch, test_name + "_")(ort_tensor)
assert torch.allclose(cpu_tensor, ort_tensor.cpu(), equal_nan=True)
@ -648,10 +643,8 @@ class OrtOpTests(unittest.TestCase):
cpu_out_tensor = torch.tensor([], dtype=tensor_test.dtype)
ort_out_tensor = cpu_out_tensor.to(device)
st_cpu = f"torch.{test_name}(cpu_tensor, out=cpu_out_tensor)"
st_ort = f"torch.{test_name}(ort_tensor, out=ort_out_tensor)"
cpu_result = eval(compile(st_cpu, "<string>", "eval"))
ort_result = eval(compile(st_ort, "<string>", "eval"))
cpu_result = getattr(torch, test_name)(cpu_tensor, out=cpu_out_tensor)
ort_result = getattr(torch, test_name)(ort_tensor, out=ort_out_tensor)
assert torch.allclose(cpu_result, ort_result.cpu(), equal_nan=True)
assert torch.allclose(cpu_out_tensor, ort_out_tensor.cpu(), equal_nan=True)
@ -670,12 +663,9 @@ class OrtOpTests(unittest.TestCase):
for tensor_type in {torch.float, torch.bool}:
cpu_out_tensor = torch.tensor([], dtype=tensor_type)
ort_out_tensor = cpu_out_tensor.to(device)
cpu_a_b_result = eval(
compile("torch." + math_sign_ops + "(cpu_a, cpu_b, out=cpu_out_tensor)", "<string>", "eval")
)
ort_a_b_result = eval(
compile("torch." + math_sign_ops + "(ort_a, ort_b, out=ort_out_tensor)", "<string>", "eval")
)
cpu_a_b_result = getattr(torch, math_sign_ops)(cpu_a, cpu_b, out=cpu_out_tensor)
ort_a_b_result = getattr(torch, math_sign_ops)(ort_a, ort_b, out=ort_out_tensor)
assert torch.equal(cpu_a_b_result.to(device), ort_a_b_result)
assert torch.equal(cpu_out_tensor, ort_out_tensor.to("cpu"))
assert ort_out_tensor.dtype == tensor_type
@ -699,35 +689,15 @@ class OrtOpTests(unittest.TestCase):
cpu_out_tensor = torch.tensor([], dtype=torch.bool)
ort_out_tensor = cpu_out_tensor.to(device)
cpu_int_int_result = eval(
compile(
"torch." + math_sign_ops + "(cpu_tensor_int, cpu_scalar_int_lt, out=cpu_out_tensor)", "<string>", "eval"
)
)
cpu_int_int_gt_result = eval(
compile("torch." + math_sign_ops + "(cpu_tensor_int, cpu_scalar_int_gt)", "<string>", "eval")
)
cpu_float_float_lt_result = eval(
compile("torch." + math_sign_ops + "(cpu_tensor_float, float_lt)", "<string>", "eval")
)
cpu_float_float_gt_result = eval(
compile("torch." + math_sign_ops + "(cpu_tensor_float, float_gt)", "<string>", "eval")
)
cpu_int_int_result = getattr(torch, math_sign_ops)(cpu_tensor_int, cpu_scalar_int_lt, out=cpu_out_tensor)
cpu_int_int_gt_result = getattr(torch, math_sign_ops)(cpu_tensor_int, cpu_scalar_int_gt)
cpu_float_float_lt_result = getattr(torch, math_sign_ops)(cpu_tensor_float, float_lt)
cpu_float_float_gt_result = getattr(torch, math_sign_ops)(cpu_tensor_float, float_gt)
ort_int_int_result = eval(
compile(
"torch." + math_sign_ops + "(ort_tensor_int, ort_scalar_int_lt, out=ort_out_tensor)", "<string>", "eval"
)
)
ort_int_int_gt_result = eval(
compile("torch." + math_sign_ops + "(ort_tensor_int, ort_scalar_int_gt)", "<string>", "eval")
)
ort_float_float_lt_result = eval(
compile("torch." + math_sign_ops + "(ort_tensor_float, float_lt)", "<string>", "eval")
)
ort_float_float_gt_result = eval(
compile("torch." + math_sign_ops + "(ort_tensor_float, float_gt)", "<string>", "eval")
)
ort_int_int_result = getattr(torch, math_sign_ops)(ort_tensor_int, ort_scalar_int_lt, out=ort_out_tensor)
ort_int_int_gt_result = getattr(torch, math_sign_ops)(ort_tensor_int, ort_scalar_int_gt)
ort_float_float_lt_result = getattr(torch, math_sign_ops)(ort_tensor_float, float_lt)
ort_float_float_gt_result = getattr(torch, math_sign_ops)(ort_tensor_float, float_gt)
assert torch.equal(cpu_out_tensor, ort_out_tensor.to("cpu"))
assert torch.equal(cpu_int_int_result, ort_int_int_result.to("cpu"))
@ -735,88 +705,65 @@ class OrtOpTests(unittest.TestCase):
assert torch.equal(cpu_float_float_lt_result, ort_float_float_lt_result.to("cpu"))
assert torch.equal(cpu_float_float_gt_result, ort_float_float_gt_result.to("cpu"))
binary_ops = [ # [op, op_sign, alpha_supported]
["add", "+", True],
["sub", "-", True],
["mul", "*", False],
["div", "/", False],
binary_ops = [ # [op, op, alpha_supported]
["add", operator.add, True],
["sub", operator.sub, True],
["mul", operator.mul, False],
["div", operator.truediv, False],
]
@parameterized.expand(binary_ops, name_func=rename_func)
def test_op_binary_tensor(self, binary_op, op_sign, alpha_supported):
def test_op_binary_tensor(self, binary_op, op, alpha_supported):
device = self.get_device()
cpu_input = torch.rand(3, 1) # use broadcasting in the second dim.
ort_input = cpu_input.to(device)
cpu_other = torch.rand(3, 3)
ort_other = cpu_other.to(device)
# verify op_sign works
cpu_result = eval(compile("cpu_input " + op_sign + " cpu_other", "<string>", "eval"))
ort_result = eval(compile("ort_input " + op_sign + " ort_other", "<string>", "eval"))
# verify op works
cpu_result = op(cpu_input, cpu_other)
ort_result = op(ort_input, ort_other)
assert torch.allclose(cpu_result, ort_result.cpu())
# verify torch op with out param works
cpu_out_tensor = torch.tensor([])
ort_out_tensor = cpu_out_tensor.to(device)
cpu_result = eval(
compile("torch." + binary_op + "(cpu_input, cpu_other, out=cpu_out_tensor)", "<string>", "eval")
)
ort_result = eval(
compile("torch." + binary_op + "(ort_input, ort_other, out=ort_out_tensor)", "<string>", "eval")
)
cpu_result = getattr(torch, binary_op)(cpu_input, cpu_other, out=cpu_out_tensor)
ort_result = getattr(torch, binary_op)(ort_input, ort_other, out=ort_out_tensor)
assert torch.allclose(cpu_result, ort_result.cpu())
assert torch.allclose(cpu_out_tensor, ort_out_tensor.cpu())
if alpha_supported:
cpu_result = eval(
compile(
"torch." + binary_op + "(cpu_input, cpu_other, alpha=2.5, out=cpu_out_tensor)", "<string>", "eval"
)
)
ort_result = eval(
compile(
"torch." + binary_op + "(ort_input, ort_other, alpha=2.5, out=ort_out_tensor)", "<string>", "eval"
)
)
cpu_result = getattr(torch, binary_op)(cpu_input, cpu_other, alpha=2.5, out=cpu_out_tensor)
ort_result = getattr(torch, binary_op)(ort_input, ort_other, alpha=2.5, out=ort_out_tensor)
assert torch.allclose(cpu_result, ort_result.cpu())
assert torch.allclose(cpu_out_tensor, ort_out_tensor.cpu())
@parameterized.expand(binary_ops, name_func=rename_func)
def test_op_binary_scalar(self, binary_op, op_sign, alpha_supported):
def test_op_binary_scalar(self, binary_op, op, alpha_supported):
device = self.get_device()
cpu_input = torch.ones(3, 3)
ort_input = cpu_input.to(device)
cpu_other = 3.1
ort_other = 3.1
# verify op_sign works
cpu_result = eval(compile("cpu_input " + op_sign + " cpu_other", "<string>", "eval"))
ort_result = eval(compile("ort_input " + op_sign + " ort_other", "<string>", "eval"))
# verify op works
cpu_result = op(cpu_input, cpu_other)
ort_result = op(ort_input, ort_other)
assert torch.allclose(cpu_result, ort_result.cpu())
# verify torch op with out param works
cpu_out_tensor = torch.tensor([])
ort_out_tensor = cpu_out_tensor.to(device)
cpu_result = eval(
compile("torch." + binary_op + "(cpu_input, cpu_other, out=cpu_out_tensor)", "<string>", "eval")
)
ort_result = eval(
compile("torch." + binary_op + "(ort_input, ort_other, out=ort_out_tensor)", "<string>", "eval")
)
cpu_result = getattr(torch, binary_op)(cpu_input, cpu_other, out=cpu_out_tensor)
ort_result = getattr(torch, binary_op)(ort_input, ort_other, out=ort_out_tensor)
assert torch.allclose(cpu_result, ort_result.cpu())
assert torch.allclose(cpu_out_tensor, ort_out_tensor.cpu())
if alpha_supported:
cpu_result = eval(
compile(
"torch." + binary_op + "(cpu_input, cpu_other, alpha=2.5, out=cpu_out_tensor)", "<string>", "eval"
)
)
ort_result = eval(
compile(
"torch." + binary_op + "(ort_input, ort_other, alpha=2.5, out=ort_out_tensor)", "<string>", "eval"
)
)
cpu_result = getattr(torch, binary_op)(cpu_input, cpu_other, alpha=2.5, out=cpu_out_tensor)
ort_result = getattr(torch, binary_op)(ort_input, ort_other, alpha=2.5, out=ort_out_tensor)
assert torch.allclose(cpu_result, ort_result.cpu())
assert torch.allclose(cpu_out_tensor, ort_out_tensor.cpu())

View file

@ -3,19 +3,20 @@
# orttraining_test_ortmodule_api.py
import copy
import inspect
import itertools
import math
import os
import pickle
import random
import tempfile
import time
import unittest.mock
import warnings
from collections import OrderedDict, namedtuple
from inspect import signature
from time import sleep
from unittest.mock import patch
import _test_helpers
import numpy as np
import onnx
import pytest
import torch
@ -380,12 +381,12 @@ def run_before_test_session(request):
request.addfinalizer(remove_disable_fallback_from_env)
# TODO: This is a workaround for the problem that pytest is still cleaning up the previous test
# FIXME: This is a workaround for the problem that pytest is still cleaning up the previous test
# while the next task has already started.
@pytest.fixture(autouse=True)
def run_before_tests():
# wait for 50ms before starting the next test
sleep(0.05)
time.sleep(0.05)
def _get_bert_for_sequence_classification_model(
@ -454,7 +455,7 @@ def test_forward_call_single_positional_argument():
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
ort_model = ORTModule(model)
# Check that the original forward signature is preserved.
assert signature(model.forward) == signature(ort_model.forward)
assert inspect.signature(model.forward) == inspect.signature(ort_model.forward)
x = torch.randn(N, D_in, device=device)
# Make sure model runs without any exception
prediction = ort_model(x)
@ -470,7 +471,7 @@ def test_forward_call_multiple_positional_arguments():
model = NeuralNetMultiplePositionalArguments(input_size=D_in, hidden_size=H, num_classes=D_out).to(device)
ort_model = ORTModule(model)
# Check that the original forward signature is preserved.
assert signature(model.forward) == signature(ort_model.forward)
assert inspect.signature(model.forward) == inspect.signature(ort_model.forward)
x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_in, device=device)
@ -535,41 +536,42 @@ def test_forward_call_positional_and_keyword_arguments():
prediction.backward()
_ONE = torch.FloatTensor([1])
@pytest.mark.parametrize(
"forward_statement",
"forward_function",
[
"model(one)",
"model(x=one)",
"model(one, None, None)",
"model(one, None, z=None)",
"model(one, None)",
"model(x=one, y=one)",
"model(y=one, x=one)",
"model(y=one, z=None, x=one)",
"model(one, None, z=one)",
"model(x=one, z=one)",
"model(one, z=one)",
"model(one, z=one, y=one)",
"model(one, one, one)",
"model(one, None, one)",
"model(z=one, x=one, y=one)",
"model(z=one, x=one, y=None)",
lambda model: model(_ONE),
lambda model: model(x=_ONE),
lambda model: model(_ONE, None, None),
lambda model: model(_ONE, None, z=None),
lambda model: model(_ONE, None),
lambda model: model(x=_ONE, y=_ONE),
lambda model: model(y=_ONE, x=_ONE),
lambda model: model(y=_ONE, z=None, x=_ONE),
lambda model: model(_ONE, None, z=_ONE),
lambda model: model(x=_ONE, z=_ONE),
lambda model: model(_ONE, z=_ONE),
lambda model: model(_ONE, z=_ONE, y=_ONE),
lambda model: model(_ONE, _ONE, _ONE),
lambda model: model(_ONE, None, _ONE),
lambda model: model(z=_ONE, x=_ONE, y=_ONE),
lambda model: model(z=_ONE, x=_ONE, y=None),
],
)
def test_compare_pytorch_forward_call_positional_and_keyword_arguments(forward_statement):
one = torch.FloatTensor([1])
def test_compare_pytorch_forward_call_positional_and_keyword_arguments(forward_function):
model = NeuralNetSimplePositionalAndKeywordArguments()
pytorch_result = eval(forward_statement + ".item()")
pytorch_result = forward_function(model).item()
model = NeuralNetSimplePositionalAndKeywordArguments()
model = ORTModule(model)
ortmodule_result = eval(forward_statement + ".item()")
ortmodule_result_again = eval(forward_statement + ".item()")
ortmodule_result = forward_function(model).item()
ortmodule_result_again = forward_function(model).item()
assert ortmodule_result == ortmodule_result_again
assert pytorch_result == ortmodule_result
prediction = eval(forward_statement).sum()
prediction = forward_function(model).sum()
prediction.backward()
@ -1661,8 +1663,6 @@ def test_aten_multinomial(input_shape, num_samples, replacement):
@pytest.mark.parametrize("input_shape", ([4, 2],))
def test_aten_argmax(input_shape):
import torch.nn.functional as F
class TopKGate(torch.nn.Module):
def forward(self, input: torch.Tensor):
indices = torch.argmax(input, dim=1)
@ -2189,7 +2189,6 @@ def test_ortmodule_inputs_with_dynamic_shape():
def test_bert_inputs_with_dynamic_shape():
# create pytorch model with dropout disabled
pt_model = _get_bert_for_sequence_classification_model(
"cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0
@ -2440,7 +2439,7 @@ def test_gpu_reserved_memory_with_torch_no_grad():
model_without_no_grad = ORTModule(model_without_no_grad)
mem_reserved_after_export_without_torch_no_grad = 0
with patch("torch.no_grad"):
with unittest.mock.patch("torch.no_grad"):
model_without_no_grad(x, attention_mask=y, labels=z)
mem_reserved_after_export_without_torch_no_grad = torch.cuda.memory_reserved(device)
@ -2608,9 +2607,7 @@ def test_exception_raised_for_custom_class_return_value_module(device):
y = torch.randn(N, D_in, device=device)
z = torch.randn(N, D_in, device=device)
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DATA):
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _fallback._FallbackPolicy.FALLBACK_UNSUPPORTED_DATA):
# Fallback
pt_out = pt_model(x, y, z)
ort_out = ort_model(x, y, z)
@ -2664,9 +2661,7 @@ def test_model_with_multiple_devices_cpu_cuda():
pt_model = MultipleDeviceModel()
x = torch.randn(20, 10)
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _fallback._FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
# Fallback
ort_model = ORTModule(copy.deepcopy(pt_model))
with pytest.raises(RuntimeError) as runtime_error:
@ -2695,9 +2690,8 @@ def test_model_with_multiple_devices_to_to():
pt_model = MultipleDeviceModel()
x = torch.randn(20, 10)
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _fallback._FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
# Fallback
with pytest.raises(RuntimeError) as runtime_error:
ort_model = ORTModule(copy.deepcopy(pt_model))
@ -2726,9 +2720,8 @@ def test_model_with_multiple_devices_to_cpu():
pt_model = MultipleDeviceModel()
x = torch.randn(20, 10)
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _fallback._FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
# Fallback
ort_model = ORTModule(copy.deepcopy(pt_model))
with pytest.raises(RuntimeError) as runtime_error:
@ -2757,9 +2750,8 @@ def test_model_with_multiple_devices_to_cuda():
pt_model = MultipleDeviceModel()
x = torch.randn(20, 10)
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _fallback._FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
# Fallback
ort_model = ORTModule(copy.deepcopy(pt_model))
with pytest.raises(RuntimeError) as runtime_error:
@ -2776,7 +2768,6 @@ def test_model_with_multiple_devices_to_cuda():
@pytest.mark.parametrize("device", ["cuda", "cuda:0", "cuda:1", "cuda:2"])
def test_model_with_different_cuda_devices(device):
# Trick to run this test in single GPU machines
device_id = _utils.get_device_index(device)
if device_id >= torch.cuda.device_count():
@ -2933,7 +2924,6 @@ def test_nested_return_value_module(device):
@pytest.mark.parametrize("data_device, model_device", (["cuda", "cpu"], ["cpu", "cuda"]))
def test_forward_data_and_model_on_different_devices(data_device, model_device):
os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"
N, D_in, H, D_out = 64, 784, 500, 10
@ -2941,13 +2931,12 @@ def test_forward_data_and_model_on_different_devices(data_device, model_device):
ort_model = ORTModule(model)
# When exporting the model, ensure device is same between input data and model (else pytorch will raise while exporting)
x = torch.randn(N, D_in, device=model_device)
output = ort_model(x)
_ = ort_model(x)
# Now that the model has been exported, feed in data from device other than the model device
x = torch.randn(N, D_in, device=data_device)
from onnxruntime.training.ortmodule._fallback import ORTModuleDeviceException, _FallbackPolicy
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _fallback._FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
# Fallback
with pytest.raises(RuntimeError) as runtime_error:
ort_model(x)
@ -2956,7 +2945,7 @@ def test_forward_data_and_model_on_different_devices(data_device, model_device):
)
else:
# ORT backend
with pytest.raises(ORTModuleDeviceException) as runtime_error:
with pytest.raises(_fallback.ORTModuleDeviceException) as runtime_error:
ort_model(x)
assert (
f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._torch_module._execution_manager(ort_model._is_training())._device}."
@ -3067,7 +3056,6 @@ def test_model_wrapped_inside_torch_no_grad():
def test_model_initializer_requires_grad_changes_from_one_forward_to_next():
os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"
device = "cuda"
@ -3129,7 +3117,7 @@ def test_model_with_registered_buffers():
model = NeuralNetWithRegisteredBuffer(D_in, H, D_out).to(device)
ort_model = ORTModule(model)
# Check that the original forward signature is preserved.
assert signature(model.forward) == signature(ort_model.forward)
assert inspect.signature(model.forward) == inspect.signature(ort_model.forward)
x = torch.randn(N, D_in, device=device)
# Make sure model runs without any exception
output = ort_model(x)
@ -3161,7 +3149,7 @@ def test_model_with_unused_registered_buffers():
model = UnusedBufferNet(D_in, H, D_out).to(device)
ort_model = ORTModule(model)
# Check that the original forward signature is preserved.
assert signature(model.forward) == signature(ort_model.forward)
assert inspect.signature(model.forward) == inspect.signature(ort_model.forward)
x = torch.randn(N, D_in, device=device)
# Make sure model runs without any exception
output = ort_model(x)
@ -3194,7 +3182,7 @@ def test_model_with_constant_and_registered_parameters():
model = NeuralNetWithRegisteredParamsWithConstant(D_in, H, D_out).to(device)
ort_model = ORTModule(model)
# Check that the original forward signature is preserved.
assert signature(model.forward) == signature(ort_model.forward)
assert inspect.signature(model.forward) == inspect.signature(ort_model.forward)
x = torch.randn(N, D_in, device=device)
# Make sure model runs without any exception
output = ort_model(x)
@ -3460,7 +3448,6 @@ def test_train_eval_with_various_outputs():
def test_forward_dynamic_args():
os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"
device = "cuda"
@ -3474,7 +3461,6 @@ def test_forward_dynamic_args():
# Make sure model runs without any exception
for i in range(2):
# Test both train and inference mode
if i % 2 == 0:
model.train()
@ -3506,7 +3492,6 @@ def test_forward_dynamic_args():
def test_forward_dynamic_kwargs():
os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"
one = torch.FloatTensor([1])
@ -3515,7 +3500,6 @@ def test_forward_dynamic_kwargs():
# Make sure model runs without any exception
for i in range(2):
# Test both train and inference mode
if i % 2 == 0:
model.train()
@ -3562,46 +3546,48 @@ def test_forward_dynamic_kwargs():
@pytest.mark.parametrize(
"forward_statement",
"forward_function",
[ # Only pos_X, pos_X as positionals
"model(pos_0, pos_1)",
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1),
# Only pos_X, pos_X as keywords
"model(pos_0=pos_0, pos_1=pos_1)",
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0=pos_0, pos_1=pos_1),
# pos_X + *args, pos_X as positionals
"model(pos_0, pos_1, *args)",
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, *args),
# pos_X + kw_X, pos_X as positionals
"model(pos_0, pos_1, kw_0=kw_0, kw_1=kw_1)",
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, kw_0=kw_0, kw_1=kw_1),
# pos_X + kw_X, pos_X as keywords
"model(pos_0=pos_0, pos_1=pos_1, kw_0=kw_0, kw_1=kw_1)",
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0=pos_0, pos_1=pos_1, kw_0=kw_0, kw_1=kw_1),
# pos_X + kw_X, pos_X as positionals (missing kw_1)
"model(pos_0, pos_1, kw_0=kw_0)",
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, kw_0=kw_0),
# pos_X + kw_X, pos_X as keywords (missing kw_1)
"model(pos_0=pos_0, pos_1=pos_1, kw_0=kw_0)",
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0=pos_0, pos_1=pos_1, kw_0=kw_0),
# pos_X + kw_X, pos_X as positionals (missing kw_0)
"model(pos_0, pos_1, kw_1=kw_1)",
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, kw_1=kw_1),
# pos_X + kw_X, pos_X as keywords (missing kw_0)
"model(pos_0=pos_0, pos_1=pos_1, kw_1=kw_1)",
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0=pos_0, pos_1=pos_1, kw_1=kw_1),
# pos_X + kwargs, pos_X as positionals
"model(pos_0, pos_1, **kwargs)",
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, **kwargs),
# pos_X + kwargs, pos_X as keywords
"model(pos_0=pos_0, pos_1=pos_1, **kwargs)",
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0=pos_0, pos_1=pos_1, **kwargs),
# pos_X + *args + kw_X, pos_X as positionals
"model(pos_0, pos_1, *args, kw_0=kw_0, kw_1=kw_1)",
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, *args, kw_0=kw_0, kw_1=kw_1),
# pos_X + *args + kw_X, pos_X as positionals (missing kw_0)
"model(pos_0, pos_1, *args, kw_1=kw_1)",
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, *args, kw_1=kw_1),
# pos_X + *args + kw_X, pos_X as positionals (missing kw_1)
"model(pos_0, pos_1, *args, kw_0=kw_0)",
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, *args, kw_0=kw_0),
# pos_X + *args + kwargs, pos_X as positionals
"model(pos_0, pos_1, *args, **kwargs)",
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, *args, **kwargs),
# pos_X + *args + kw_X + kwargs, pos_X as positionals
"model(pos_0, pos_1, *args, kw_0=kw_0, kw_1=kw_1, **kwargs)",
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(
pos_0, pos_1, *args, kw_0=kw_0, kw_1=kw_1, **kwargs
),
# pos_X + *args + kw_X + kwargs, pos_X as positionals (missing kw_0)
"model(pos_0, pos_1, *args, kw_1=kw_1, **kwargs)",
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, *args, kw_1=kw_1, **kwargs),
# pos_X + *args + kw_X + kwargs, pos_X as positionals (missing kw_1)
"model(pos_0, pos_1, *args, kw_0=kw_0, **kwargs)",
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, *args, kw_0=kw_0, **kwargs),
],
)
def test_forward_call_kwargs_input(forward_statement):
def test_forward_call_kwargs_input(forward_function):
class KwargsNet(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(KwargsNet, self).__init__()
@ -3644,7 +3630,7 @@ def test_forward_call_kwargs_input(forward_statement):
kwargs = {"kwargs_0": torch.randn(N, D_in, device=device), "kwargs_1": torch.randn(D_in, D_in, device=device)}
# Training step
prediction = eval(forward_statement)
prediction = forward_function(model, pos_0, pos_1, kw_0, kw_1, args, kwargs)
assert prediction is not None
prediction = prediction.sum()
prediction.backward()
@ -3669,7 +3655,6 @@ def test_repro_iscontiguous():
def test_forward_call_default_input():
os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"
class UnusedNet(torch.nn.Module):
@ -3795,7 +3780,6 @@ def test_forward_call_kwargs_input_unexpected_order():
def test_forward_call_lots_None():
os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"
class NoneNet(torch.nn.Module):
@ -3943,7 +3927,6 @@ def test_primitive_inputs(bool_argument, int_argument, float_argument):
@pytest.mark.parametrize("bool_arguments", [(True, False), (False, True)])
def test_changing_bool_input_re_exports_model(bool_arguments):
os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"
class PrimitiveTypesInputNet(torch.nn.Module):
@ -4116,7 +4099,6 @@ def test_output_order():
@pytest.mark.parametrize("device", ["cuda", "cpu", None])
def test_stateless_model_specified_device(device):
N, D_in, H, D_out = 32, 784, 500, 10
pt_model = StatelessModel().to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
@ -4131,7 +4113,6 @@ def test_stateless_model_specified_device(device):
def test_stateless_model_unspecified_device():
N, D_in, H, D_out = 32, 784, 500, 10
pt_model = StatelessModel()
ort_model = ORTModule(copy.deepcopy(pt_model))
@ -4238,7 +4219,6 @@ def test_hf_save_pretrained():
def test_ortmodule_string_inputs_are_ignored():
pt_model = MyStrNet()
ort_model = ORTModule(copy.deepcopy(pt_model))
x = torch.randn(1, 2)
@ -4346,7 +4326,6 @@ def test_ortmodule_nested_list_input():
@pytest.mark.parametrize("mode", ["training", "inference"])
def test_debug_options_save_onnx_models_os_environment(mode):
device = "cuda"
N, D_in, H, D_out = 64, 784, 500, 10
# Create a temporary directory for the onnx_models
@ -4370,7 +4349,6 @@ def test_debug_options_save_onnx_models_os_environment(mode):
@pytest.mark.parametrize("mode", ["training", "inference"])
def test_debug_options_save_onnx_models_cwd(mode):
device = "cuda"
N, D_in, H, D_out = 64, 784, 500, 10
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
@ -4395,7 +4373,6 @@ def test_debug_options_save_onnx_models_cwd(mode):
def test_debug_options_save_onnx_models_validate_fail_on_non_writable_dir():
os.environ["ORTMODULE_SAVE_ONNX_PATH"] = "/non/existent/directory"
with pytest.raises(Exception) as ex_info:
_ = DebugOptions(save_onnx=True, onnx_prefix="my_model")
@ -4793,7 +4770,6 @@ def test_ortmodule_setattr_ortmodule_attribute():
def test_ortmodule_setattr_signals_model_changed():
os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"
class UserNet(torch.nn.Module):
@ -4928,7 +4904,6 @@ def test_ortmodule_skip_check_load_from_os_env(policy_str, policy):
@pytest.mark.parametrize("is_training,deterministic", list(itertools.product([True, False], repeat=2)))
def test_ortmodule_determinism_flag(is_training, deterministic):
torch.use_deterministic_algorithms(deterministic)
N, D_in, H, D_out = 64, 784, 500, 10
@ -4940,9 +4915,7 @@ def test_ortmodule_determinism_flag(is_training, deterministic):
x = torch.randn(N, D_in)
_ = model(x)
from onnxruntime.training.ortmodule import _are_deterministic_algorithms_enabled
assert _are_deterministic_algorithms_enabled() is torch.are_deterministic_algorithms_enabled()
assert ortmodule_module._are_deterministic_algorithms_enabled() is torch.are_deterministic_algorithms_enabled()
def test_ortmodule_gradient_builder():
@ -5053,7 +5026,6 @@ def test_override_pytorch_exporter_kwargs_using_ortmodule_extension():
def test_ortmodule_fused_adam_optimizer_correctness():
torch.manual_seed(8888)
device = "cuda"
@ -5102,7 +5074,6 @@ def test_ortmodule_fused_adam_optimizer_correctness():
def test_ortmodule_fused_adam_optimizer_correctness_torch():
torch.manual_seed(8888)
device = "cuda"
@ -5225,13 +5196,11 @@ def test_tanh_grad():
def test__defined_from_envvar():
from onnxruntime.training import ortmodule
os.environ["DUMMY_ORTMODULE"] = "15"
assert ortmodule._defined_from_envvar("DUMMY_ORTMODULE", 14) == 15
assert ortmodule_module._defined_from_envvar("DUMMY_ORTMODULE", 14) == 15
os.environ["DUMMY_ORTMODULE"] = "15j"
with warnings.catch_warnings(record=True) as w:
assert ortmodule._defined_from_envvar("DUMMY_ORTMODULE", 14) == 14
assert ortmodule_module._defined_from_envvar("DUMMY_ORTMODULE", 14) == 14
assert len(w) == 1
assert issubclass(w[-1].category, UserWarning)
assert "Unable to overwrite constant" in str(w[-1].message)
@ -5262,12 +5231,10 @@ def test_sigmoid_grad_opset13():
N, D_in, H, D_out = 120, 15360, 500, 15360
pt_model = NeuralNetSigmoid(D_in, H, D_out).to(device)
from onnxruntime.training import ortmodule
old_opst_cst = ortmodule.ONNX_OPSET_VERSION
old_opst_cst = ortmodule_module.ONNX_OPSET_VERSION
old_opset = os.getenv("ORTMODULE_ONNX_OPSET_VERSION", None)
os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = "13"
assert ortmodule.ONNX_OPSET_VERSION == 15
assert ortmodule_module.ONNX_OPSET_VERSION == 15
ort_model = ORTModule(copy.deepcopy(pt_model))
@ -5293,8 +5260,8 @@ def test_sigmoid_grad_opset13():
del os.environ["ORTMODULE_ONNX_OPSET_VERSION"]
else:
os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = old_opset
assert ortmodule.ONNX_OPSET_VERSION == 13
ortmodule.ONNX_OPSET_VERSION = old_opst_cst
assert ortmodule_module.ONNX_OPSET_VERSION == 13
ortmodule_module.ONNX_OPSET_VERSION = old_opst_cst
@pytest.mark.parametrize("opset_version", [12, 13, 14, 15])
@ -5307,10 +5274,7 @@ def test_opset_version_change(opset_version):
ort_model = ORTModule(model)
# Must import a namespace containing ONNX_OPSET_VERSION, not ONNX_OPSET_VERSION directly
from onnxruntime.training import ortmodule
ortmodule.ONNX_OPSET_VERSION = opset_version
ortmodule_module.ONNX_OPSET_VERSION = opset_version
# Make sure model runs without any exception
prediction = ort_model(x)
@ -5324,7 +5288,6 @@ def test_opset_version_change(opset_version):
def test_serialize_ortmodule():
device = "cuda"
N, D_in, H, D_out = 64, 784, 500, 10
pt_model = SerializationNet(D_in, H, D_out).to(device)
@ -5455,8 +5418,6 @@ def test_check_opset_is_default_opset_after_training():
def test_random_states_unchanged_for_ortmodule():
import numpy
os.environ["ORTMODULE_FALLBACK_RETRY"] = "False"
class NeuralNetSlice(torch.nn.Module):
@ -5473,8 +5434,8 @@ def test_random_states_unchanged_for_ortmodule():
if isinstance(a, tuple):
assert len(a) == len(b)
return all([random_state_equal(a_i, b_i) for a_i, b_i in zip(a, b)])
if isinstance(a, numpy.ndarray):
return numpy.array_equal(a, b)
if isinstance(a, np.ndarray):
return np.array_equal(a, b)
if isinstance(a, torch.Tensor):
return torch.equal(a, b)
return a == b