mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Remove the use of eval in test code (#15097)
### Description Remove the use of `eval` in test code so we don't (1) use eval and (2) create "unused" local vars that ruff will remove. Predecessor to #15085
This commit is contained in:
parent
c964da7ea2
commit
bdd7bd084c
2 changed files with 121 additions and 213 deletions
|
|
@ -3,12 +3,12 @@
|
|||
|
||||
# pylint: disable=missing-docstring, too-many-public-methods, no-member
|
||||
|
||||
import operator
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime_pybind11_state as torch_ort
|
||||
import torch
|
||||
from parameterized import parameterized, param
|
||||
from parameterized import param, parameterized
|
||||
|
||||
|
||||
class OrtOpTests(unittest.TestCase):
|
||||
|
|
@ -570,11 +570,6 @@ class OrtOpTests(unittest.TestCase):
|
|||
# for floor and erf, the ort produces a roundoff error for NaN input, but cpu keeps it a NaN.
|
||||
# Thus, we use nan_to_num to ensure actual numbers are passed in.
|
||||
|
||||
# As many of the following use eval and make it appear to pylint that there are many unused variables,
|
||||
# we disable those warnings
|
||||
|
||||
# pylint: disable=eval-used, unused-argument, unused-variable, no-self-argument,
|
||||
|
||||
ops = [
|
||||
["abs", torch.tensor([-1, -2, 3, -6, -7])],
|
||||
["acos"],
|
||||
|
|
@ -616,20 +611,20 @@ class OrtOpTests(unittest.TestCase):
|
|||
# @parameterized.expand generate test methods for ops and using name_func we renaming the test to be test_{ops}
|
||||
@parameterized.expand(ops, name_func=rename_func)
|
||||
def test_op(self, test_name, tensor_test=torch.rand(6)):
|
||||
# compile eval- creates a code object that evaluates the operator (for example torch.abs(tensor_test)) and returns its result.
|
||||
cpu_result = eval(compile("torch." + test_name + "(tensor_test)", "<string>", "eval"))
|
||||
ort_result = eval(compile("torch." + test_name + "(tensor_test.to(self.get_device()))", "<string>", "eval"))
|
||||
cpu_result = getattr(torch, test_name)(tensor_test)
|
||||
ort_result = getattr(torch, test_name)(tensor_test.to(self.get_device()))
|
||||
|
||||
assert torch.allclose(cpu_result, ort_result.cpu(), equal_nan=True)
|
||||
|
||||
@parameterized.expand(ops, name_func=rename_func)
|
||||
def test_op_(self, test_name, tensor_test=torch.rand(6)):
|
||||
def test_op_inplace(self, test_name, tensor_test=torch.rand(6)):
|
||||
device = self.get_device()
|
||||
|
||||
cpu_tensor = tensor_test
|
||||
ort_tensor = cpu_tensor.to(device)
|
||||
|
||||
eval(compile("torch." + test_name + "_(cpu_tensor)", "<string>", "eval"))
|
||||
eval(compile("torch." + test_name + "_(ort_tensor)", "<string>", "eval"))
|
||||
getattr(torch, test_name + "_")(cpu_tensor)
|
||||
getattr(torch, test_name + "_")(ort_tensor)
|
||||
|
||||
assert torch.allclose(cpu_tensor, ort_tensor.cpu(), equal_nan=True)
|
||||
|
||||
|
|
@ -648,10 +643,8 @@ class OrtOpTests(unittest.TestCase):
|
|||
cpu_out_tensor = torch.tensor([], dtype=tensor_test.dtype)
|
||||
ort_out_tensor = cpu_out_tensor.to(device)
|
||||
|
||||
st_cpu = f"torch.{test_name}(cpu_tensor, out=cpu_out_tensor)"
|
||||
st_ort = f"torch.{test_name}(ort_tensor, out=ort_out_tensor)"
|
||||
cpu_result = eval(compile(st_cpu, "<string>", "eval"))
|
||||
ort_result = eval(compile(st_ort, "<string>", "eval"))
|
||||
cpu_result = getattr(torch, test_name)(cpu_tensor, out=cpu_out_tensor)
|
||||
ort_result = getattr(torch, test_name)(ort_tensor, out=ort_out_tensor)
|
||||
|
||||
assert torch.allclose(cpu_result, ort_result.cpu(), equal_nan=True)
|
||||
assert torch.allclose(cpu_out_tensor, ort_out_tensor.cpu(), equal_nan=True)
|
||||
|
|
@ -670,12 +663,9 @@ class OrtOpTests(unittest.TestCase):
|
|||
for tensor_type in {torch.float, torch.bool}:
|
||||
cpu_out_tensor = torch.tensor([], dtype=tensor_type)
|
||||
ort_out_tensor = cpu_out_tensor.to(device)
|
||||
cpu_a_b_result = eval(
|
||||
compile("torch." + math_sign_ops + "(cpu_a, cpu_b, out=cpu_out_tensor)", "<string>", "eval")
|
||||
)
|
||||
ort_a_b_result = eval(
|
||||
compile("torch." + math_sign_ops + "(ort_a, ort_b, out=ort_out_tensor)", "<string>", "eval")
|
||||
)
|
||||
cpu_a_b_result = getattr(torch, math_sign_ops)(cpu_a, cpu_b, out=cpu_out_tensor)
|
||||
ort_a_b_result = getattr(torch, math_sign_ops)(ort_a, ort_b, out=ort_out_tensor)
|
||||
|
||||
assert torch.equal(cpu_a_b_result.to(device), ort_a_b_result)
|
||||
assert torch.equal(cpu_out_tensor, ort_out_tensor.to("cpu"))
|
||||
assert ort_out_tensor.dtype == tensor_type
|
||||
|
|
@ -699,35 +689,15 @@ class OrtOpTests(unittest.TestCase):
|
|||
cpu_out_tensor = torch.tensor([], dtype=torch.bool)
|
||||
ort_out_tensor = cpu_out_tensor.to(device)
|
||||
|
||||
cpu_int_int_result = eval(
|
||||
compile(
|
||||
"torch." + math_sign_ops + "(cpu_tensor_int, cpu_scalar_int_lt, out=cpu_out_tensor)", "<string>", "eval"
|
||||
)
|
||||
)
|
||||
cpu_int_int_gt_result = eval(
|
||||
compile("torch." + math_sign_ops + "(cpu_tensor_int, cpu_scalar_int_gt)", "<string>", "eval")
|
||||
)
|
||||
cpu_float_float_lt_result = eval(
|
||||
compile("torch." + math_sign_ops + "(cpu_tensor_float, float_lt)", "<string>", "eval")
|
||||
)
|
||||
cpu_float_float_gt_result = eval(
|
||||
compile("torch." + math_sign_ops + "(cpu_tensor_float, float_gt)", "<string>", "eval")
|
||||
)
|
||||
cpu_int_int_result = getattr(torch, math_sign_ops)(cpu_tensor_int, cpu_scalar_int_lt, out=cpu_out_tensor)
|
||||
cpu_int_int_gt_result = getattr(torch, math_sign_ops)(cpu_tensor_int, cpu_scalar_int_gt)
|
||||
cpu_float_float_lt_result = getattr(torch, math_sign_ops)(cpu_tensor_float, float_lt)
|
||||
cpu_float_float_gt_result = getattr(torch, math_sign_ops)(cpu_tensor_float, float_gt)
|
||||
|
||||
ort_int_int_result = eval(
|
||||
compile(
|
||||
"torch." + math_sign_ops + "(ort_tensor_int, ort_scalar_int_lt, out=ort_out_tensor)", "<string>", "eval"
|
||||
)
|
||||
)
|
||||
ort_int_int_gt_result = eval(
|
||||
compile("torch." + math_sign_ops + "(ort_tensor_int, ort_scalar_int_gt)", "<string>", "eval")
|
||||
)
|
||||
ort_float_float_lt_result = eval(
|
||||
compile("torch." + math_sign_ops + "(ort_tensor_float, float_lt)", "<string>", "eval")
|
||||
)
|
||||
ort_float_float_gt_result = eval(
|
||||
compile("torch." + math_sign_ops + "(ort_tensor_float, float_gt)", "<string>", "eval")
|
||||
)
|
||||
ort_int_int_result = getattr(torch, math_sign_ops)(ort_tensor_int, ort_scalar_int_lt, out=ort_out_tensor)
|
||||
ort_int_int_gt_result = getattr(torch, math_sign_ops)(ort_tensor_int, ort_scalar_int_gt)
|
||||
ort_float_float_lt_result = getattr(torch, math_sign_ops)(ort_tensor_float, float_lt)
|
||||
ort_float_float_gt_result = getattr(torch, math_sign_ops)(ort_tensor_float, float_gt)
|
||||
|
||||
assert torch.equal(cpu_out_tensor, ort_out_tensor.to("cpu"))
|
||||
assert torch.equal(cpu_int_int_result, ort_int_int_result.to("cpu"))
|
||||
|
|
@ -735,88 +705,65 @@ class OrtOpTests(unittest.TestCase):
|
|||
assert torch.equal(cpu_float_float_lt_result, ort_float_float_lt_result.to("cpu"))
|
||||
assert torch.equal(cpu_float_float_gt_result, ort_float_float_gt_result.to("cpu"))
|
||||
|
||||
binary_ops = [ # [op, op_sign, alpha_supported]
|
||||
["add", "+", True],
|
||||
["sub", "-", True],
|
||||
["mul", "*", False],
|
||||
["div", "/", False],
|
||||
binary_ops = [ # [op, op, alpha_supported]
|
||||
["add", operator.add, True],
|
||||
["sub", operator.sub, True],
|
||||
["mul", operator.mul, False],
|
||||
["div", operator.truediv, False],
|
||||
]
|
||||
|
||||
@parameterized.expand(binary_ops, name_func=rename_func)
|
||||
def test_op_binary_tensor(self, binary_op, op_sign, alpha_supported):
|
||||
def test_op_binary_tensor(self, binary_op, op, alpha_supported):
|
||||
device = self.get_device()
|
||||
cpu_input = torch.rand(3, 1) # use broadcasting in the second dim.
|
||||
ort_input = cpu_input.to(device)
|
||||
cpu_other = torch.rand(3, 3)
|
||||
ort_other = cpu_other.to(device)
|
||||
|
||||
# verify op_sign works
|
||||
cpu_result = eval(compile("cpu_input " + op_sign + " cpu_other", "<string>", "eval"))
|
||||
ort_result = eval(compile("ort_input " + op_sign + " ort_other", "<string>", "eval"))
|
||||
# verify op works
|
||||
cpu_result = op(cpu_input, cpu_other)
|
||||
ort_result = op(ort_input, ort_other)
|
||||
assert torch.allclose(cpu_result, ort_result.cpu())
|
||||
|
||||
# verify torch op with out param works
|
||||
cpu_out_tensor = torch.tensor([])
|
||||
ort_out_tensor = cpu_out_tensor.to(device)
|
||||
cpu_result = eval(
|
||||
compile("torch." + binary_op + "(cpu_input, cpu_other, out=cpu_out_tensor)", "<string>", "eval")
|
||||
)
|
||||
ort_result = eval(
|
||||
compile("torch." + binary_op + "(ort_input, ort_other, out=ort_out_tensor)", "<string>", "eval")
|
||||
)
|
||||
cpu_result = getattr(torch, binary_op)(cpu_input, cpu_other, out=cpu_out_tensor)
|
||||
ort_result = getattr(torch, binary_op)(ort_input, ort_other, out=ort_out_tensor)
|
||||
|
||||
assert torch.allclose(cpu_result, ort_result.cpu())
|
||||
assert torch.allclose(cpu_out_tensor, ort_out_tensor.cpu())
|
||||
|
||||
if alpha_supported:
|
||||
cpu_result = eval(
|
||||
compile(
|
||||
"torch." + binary_op + "(cpu_input, cpu_other, alpha=2.5, out=cpu_out_tensor)", "<string>", "eval"
|
||||
)
|
||||
)
|
||||
ort_result = eval(
|
||||
compile(
|
||||
"torch." + binary_op + "(ort_input, ort_other, alpha=2.5, out=ort_out_tensor)", "<string>", "eval"
|
||||
)
|
||||
)
|
||||
cpu_result = getattr(torch, binary_op)(cpu_input, cpu_other, alpha=2.5, out=cpu_out_tensor)
|
||||
ort_result = getattr(torch, binary_op)(ort_input, ort_other, alpha=2.5, out=ort_out_tensor)
|
||||
assert torch.allclose(cpu_result, ort_result.cpu())
|
||||
assert torch.allclose(cpu_out_tensor, ort_out_tensor.cpu())
|
||||
|
||||
@parameterized.expand(binary_ops, name_func=rename_func)
|
||||
def test_op_binary_scalar(self, binary_op, op_sign, alpha_supported):
|
||||
def test_op_binary_scalar(self, binary_op, op, alpha_supported):
|
||||
device = self.get_device()
|
||||
cpu_input = torch.ones(3, 3)
|
||||
ort_input = cpu_input.to(device)
|
||||
cpu_other = 3.1
|
||||
ort_other = 3.1
|
||||
|
||||
# verify op_sign works
|
||||
cpu_result = eval(compile("cpu_input " + op_sign + " cpu_other", "<string>", "eval"))
|
||||
ort_result = eval(compile("ort_input " + op_sign + " ort_other", "<string>", "eval"))
|
||||
# verify op works
|
||||
cpu_result = op(cpu_input, cpu_other)
|
||||
ort_result = op(ort_input, ort_other)
|
||||
assert torch.allclose(cpu_result, ort_result.cpu())
|
||||
|
||||
# verify torch op with out param works
|
||||
cpu_out_tensor = torch.tensor([])
|
||||
ort_out_tensor = cpu_out_tensor.to(device)
|
||||
cpu_result = eval(
|
||||
compile("torch." + binary_op + "(cpu_input, cpu_other, out=cpu_out_tensor)", "<string>", "eval")
|
||||
)
|
||||
ort_result = eval(
|
||||
compile("torch." + binary_op + "(ort_input, ort_other, out=ort_out_tensor)", "<string>", "eval")
|
||||
)
|
||||
cpu_result = getattr(torch, binary_op)(cpu_input, cpu_other, out=cpu_out_tensor)
|
||||
ort_result = getattr(torch, binary_op)(ort_input, ort_other, out=ort_out_tensor)
|
||||
assert torch.allclose(cpu_result, ort_result.cpu())
|
||||
assert torch.allclose(cpu_out_tensor, ort_out_tensor.cpu())
|
||||
|
||||
if alpha_supported:
|
||||
cpu_result = eval(
|
||||
compile(
|
||||
"torch." + binary_op + "(cpu_input, cpu_other, alpha=2.5, out=cpu_out_tensor)", "<string>", "eval"
|
||||
)
|
||||
)
|
||||
ort_result = eval(
|
||||
compile(
|
||||
"torch." + binary_op + "(ort_input, ort_other, alpha=2.5, out=ort_out_tensor)", "<string>", "eval"
|
||||
)
|
||||
)
|
||||
cpu_result = getattr(torch, binary_op)(cpu_input, cpu_other, alpha=2.5, out=cpu_out_tensor)
|
||||
ort_result = getattr(torch, binary_op)(ort_input, ort_other, alpha=2.5, out=ort_out_tensor)
|
||||
assert torch.allclose(cpu_result, ort_result.cpu())
|
||||
assert torch.allclose(cpu_out_tensor, ort_out_tensor.cpu())
|
||||
|
||||
|
|
|
|||
|
|
@ -3,19 +3,20 @@
|
|||
# orttraining_test_ortmodule_api.py
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
import tempfile
|
||||
import time
|
||||
import unittest.mock
|
||||
import warnings
|
||||
from collections import OrderedDict, namedtuple
|
||||
from inspect import signature
|
||||
from time import sleep
|
||||
from unittest.mock import patch
|
||||
|
||||
import _test_helpers
|
||||
import numpy as np
|
||||
import onnx
|
||||
import pytest
|
||||
import torch
|
||||
|
|
@ -380,12 +381,12 @@ def run_before_test_session(request):
|
|||
request.addfinalizer(remove_disable_fallback_from_env)
|
||||
|
||||
|
||||
# TODO: This is a workaround for the problem that pytest is still cleaning up the previous test
|
||||
# FIXME: This is a workaround for the problem that pytest is still cleaning up the previous test
|
||||
# while the next task already start.
|
||||
@pytest.fixture(autouse=True)
|
||||
def run_before_tests():
|
||||
# wait for 50ms before starting the next test
|
||||
sleep(0.05)
|
||||
time.sleep(0.05)
|
||||
|
||||
|
||||
def _get_bert_for_sequence_classification_model(
|
||||
|
|
@ -454,7 +455,7 @@ def test_forward_call_single_positional_argument():
|
|||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
ort_model = ORTModule(model)
|
||||
# Check that the original forward signature is preserved.
|
||||
assert signature(model.forward) == signature(ort_model.forward)
|
||||
assert inspect.signature(model.forward) == inspect.signature(ort_model.forward)
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
# Make sure model runs without any exception
|
||||
prediction = ort_model(x)
|
||||
|
|
@ -470,7 +471,7 @@ def test_forward_call_multiple_positional_arguments():
|
|||
model = NeuralNetMultiplePositionalArguments(input_size=D_in, hidden_size=H, num_classes=D_out).to(device)
|
||||
ort_model = ORTModule(model)
|
||||
# Check that the original forward signature is preserved.
|
||||
assert signature(model.forward) == signature(ort_model.forward)
|
||||
assert inspect.signature(model.forward) == inspect.signature(ort_model.forward)
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
y = torch.randn(N, D_in, device=device)
|
||||
|
||||
|
|
@ -535,41 +536,42 @@ def test_forward_call_positional_and_keyword_arguments():
|
|||
prediction.backward()
|
||||
|
||||
|
||||
_ONE = torch.FloatTensor([1])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"forward_statement",
|
||||
"forward_function",
|
||||
[
|
||||
"model(one)",
|
||||
"model(x=one)",
|
||||
"model(one, None, None)",
|
||||
"model(one, None, z=None)",
|
||||
"model(one, None)",
|
||||
"model(x=one, y=one)",
|
||||
"model(y=one, x=one)",
|
||||
"model(y=one, z=None, x=one)",
|
||||
"model(one, None, z=one)",
|
||||
"model(x=one, z=one)",
|
||||
"model(one, z=one)",
|
||||
"model(one, z=one, y=one)",
|
||||
"model(one, one, one)",
|
||||
"model(one, None, one)",
|
||||
"model(z=one, x=one, y=one)",
|
||||
"model(z=one, x=one, y=None)",
|
||||
lambda model: model(_ONE),
|
||||
lambda model: model(x=_ONE),
|
||||
lambda model: model(_ONE, None, None),
|
||||
lambda model: model(_ONE, None, z=None),
|
||||
lambda model: model(_ONE, None),
|
||||
lambda model: model(x=_ONE, y=_ONE),
|
||||
lambda model: model(y=_ONE, x=_ONE),
|
||||
lambda model: model(y=_ONE, z=None, x=_ONE),
|
||||
lambda model: model(_ONE, None, z=_ONE),
|
||||
lambda model: model(x=_ONE, z=_ONE),
|
||||
lambda model: model(_ONE, z=_ONE),
|
||||
lambda model: model(_ONE, z=_ONE, y=_ONE),
|
||||
lambda model: model(_ONE, _ONE, _ONE),
|
||||
lambda model: model(_ONE, None, _ONE),
|
||||
lambda model: model(z=_ONE, x=_ONE, y=_ONE),
|
||||
lambda model: model(z=_ONE, x=_ONE, y=None),
|
||||
],
|
||||
)
|
||||
def test_compare_pytorch_forward_call_positional_and_keyword_arguments(forward_statement):
|
||||
one = torch.FloatTensor([1])
|
||||
|
||||
def test_compare_pytorch_forward_call_positional_and_keyword_arguments(forward_function):
|
||||
model = NeuralNetSimplePositionalAndKeywordArguments()
|
||||
pytorch_result = eval(forward_statement + ".item()")
|
||||
pytorch_result = forward_function(model).item()
|
||||
|
||||
model = NeuralNetSimplePositionalAndKeywordArguments()
|
||||
model = ORTModule(model)
|
||||
ortmodule_result = eval(forward_statement + ".item()")
|
||||
ortmodule_result_again = eval(forward_statement + ".item()")
|
||||
ortmodule_result = forward_function(model).item()
|
||||
ortmodule_result_again = forward_function(model).item()
|
||||
assert ortmodule_result == ortmodule_result_again
|
||||
assert pytorch_result == ortmodule_result
|
||||
|
||||
prediction = eval(forward_statement).sum()
|
||||
prediction = forward_function(model).sum()
|
||||
prediction.backward()
|
||||
|
||||
|
||||
|
|
@ -1661,8 +1663,6 @@ def test_aten_multinomial(input_shape, num_samples, replacement):
|
|||
|
||||
@pytest.mark.parametrize("input_shape", ([4, 2],))
|
||||
def test_aten_argmax(input_shape):
|
||||
import torch.nn.functional as F
|
||||
|
||||
class TopKGate(torch.nn.Module):
|
||||
def forward(self, input: torch.Tensor):
|
||||
indices = torch.argmax(input, dim=1)
|
||||
|
|
@ -2189,7 +2189,6 @@ def test_ortmodule_inputs_with_dynamic_shape():
|
|||
|
||||
|
||||
def test_bert_inputs_with_dynamic_shape():
|
||||
|
||||
# create pytorch model with dropout disabled
|
||||
pt_model = _get_bert_for_sequence_classification_model(
|
||||
"cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0
|
||||
|
|
@ -2440,7 +2439,7 @@ def test_gpu_reserved_memory_with_torch_no_grad():
|
|||
model_without_no_grad = ORTModule(model_without_no_grad)
|
||||
mem_reserved_after_export_without_torch_no_grad = 0
|
||||
|
||||
with patch("torch.no_grad"):
|
||||
with unittest.mock.patch("torch.no_grad"):
|
||||
model_without_no_grad(x, attention_mask=y, labels=z)
|
||||
mem_reserved_after_export_without_torch_no_grad = torch.cuda.memory_reserved(device)
|
||||
|
||||
|
|
@ -2608,9 +2607,7 @@ def test_exception_raised_for_custom_class_return_value_module(device):
|
|||
y = torch.randn(N, D_in, device=device)
|
||||
z = torch.randn(N, D_in, device=device)
|
||||
|
||||
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
|
||||
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DATA):
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _fallback._FallbackPolicy.FALLBACK_UNSUPPORTED_DATA):
|
||||
# Fallback
|
||||
pt_out = pt_model(x, y, z)
|
||||
ort_out = ort_model(x, y, z)
|
||||
|
|
@ -2664,9 +2661,7 @@ def test_model_with_multiple_devices_cpu_cuda():
|
|||
pt_model = MultipleDeviceModel()
|
||||
x = torch.randn(20, 10)
|
||||
|
||||
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
|
||||
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _fallback._FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
|
||||
# Fallback
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
with pytest.raises(RuntimeError) as runtime_error:
|
||||
|
|
@ -2695,9 +2690,8 @@ def test_model_with_multiple_devices_to_to():
|
|||
|
||||
pt_model = MultipleDeviceModel()
|
||||
x = torch.randn(20, 10)
|
||||
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
|
||||
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _fallback._FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
|
||||
# Fallback
|
||||
with pytest.raises(RuntimeError) as runtime_error:
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
|
|
@ -2726,9 +2720,8 @@ def test_model_with_multiple_devices_to_cpu():
|
|||
|
||||
pt_model = MultipleDeviceModel()
|
||||
x = torch.randn(20, 10)
|
||||
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
|
||||
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _fallback._FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
|
||||
# Fallback
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
with pytest.raises(RuntimeError) as runtime_error:
|
||||
|
|
@ -2757,9 +2750,8 @@ def test_model_with_multiple_devices_to_cuda():
|
|||
|
||||
pt_model = MultipleDeviceModel()
|
||||
x = torch.randn(20, 10)
|
||||
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
|
||||
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _fallback._FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
|
||||
# Fallback
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
with pytest.raises(RuntimeError) as runtime_error:
|
||||
|
|
@ -2776,7 +2768,6 @@ def test_model_with_multiple_devices_to_cuda():
|
|||
|
||||
@pytest.mark.parametrize("device", ["cuda", "cuda:0", "cuda:1", "cuda:2"])
|
||||
def test_model_with_different_cuda_devices(device):
|
||||
|
||||
# Trick to run this test in single GPU machines
|
||||
device_id = _utils.get_device_index(device)
|
||||
if device_id >= torch.cuda.device_count():
|
||||
|
|
@ -2933,7 +2924,6 @@ def test_nested_return_value_module(device):
|
|||
|
||||
@pytest.mark.parametrize("data_device, model_device", (["cuda", "cpu"], ["cpu", "cuda"]))
|
||||
def test_forward_data_and_model_on_different_devices(data_device, model_device):
|
||||
|
||||
os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"
|
||||
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
|
|
@ -2941,13 +2931,12 @@ def test_forward_data_and_model_on_different_devices(data_device, model_device):
|
|||
ort_model = ORTModule(model)
|
||||
# When exporting the model, ensure device is same between input data and model (else pytorch will raise while exporting)
|
||||
x = torch.randn(N, D_in, device=model_device)
|
||||
output = ort_model(x)
|
||||
_ = ort_model(x)
|
||||
|
||||
# Now that the model has been exported, feed in data from device other than the model device
|
||||
x = torch.randn(N, D_in, device=data_device)
|
||||
from onnxruntime.training.ortmodule._fallback import ORTModuleDeviceException, _FallbackPolicy
|
||||
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _fallback._FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
|
||||
# Fallback
|
||||
with pytest.raises(RuntimeError) as runtime_error:
|
||||
ort_model(x)
|
||||
|
|
@ -2956,7 +2945,7 @@ def test_forward_data_and_model_on_different_devices(data_device, model_device):
|
|||
)
|
||||
else:
|
||||
# ORT backend
|
||||
with pytest.raises(ORTModuleDeviceException) as runtime_error:
|
||||
with pytest.raises(_fallback.ORTModuleDeviceException) as runtime_error:
|
||||
ort_model(x)
|
||||
assert (
|
||||
f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._torch_module._execution_manager(ort_model._is_training())._device}."
|
||||
|
|
@ -3067,7 +3056,6 @@ def test_model_wrapped_inside_torch_no_grad():
|
|||
|
||||
|
||||
def test_model_initializer_requires_grad_changes_from_one_forward_to_next():
|
||||
|
||||
os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"
|
||||
|
||||
device = "cuda"
|
||||
|
|
@ -3129,7 +3117,7 @@ def test_model_with_registered_buffers():
|
|||
model = NeuralNetWithRegisteredBuffer(D_in, H, D_out).to(device)
|
||||
ort_model = ORTModule(model)
|
||||
# Check that the original forward signature is preserved.
|
||||
assert signature(model.forward) == signature(ort_model.forward)
|
||||
assert inspect.signature(model.forward) == inspect.signature(ort_model.forward)
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
# Make sure model runs without any exception
|
||||
output = ort_model(x)
|
||||
|
|
@ -3161,7 +3149,7 @@ def test_model_with_unused_registered_buffers():
|
|||
model = UnusedBufferNet(D_in, H, D_out).to(device)
|
||||
ort_model = ORTModule(model)
|
||||
# Check that the original forward signature is preserved.
|
||||
assert signature(model.forward) == signature(ort_model.forward)
|
||||
assert inspect.signature(model.forward) == inspect.signature(ort_model.forward)
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
# Make sure model runs without any exception
|
||||
output = ort_model(x)
|
||||
|
|
@ -3194,7 +3182,7 @@ def test_model_with_constant_and_registered_parameters():
|
|||
model = NeuralNetWithRegisteredParamsWithConstant(D_in, H, D_out).to(device)
|
||||
ort_model = ORTModule(model)
|
||||
# Check that the original forward signature is preserved.
|
||||
assert signature(model.forward) == signature(ort_model.forward)
|
||||
assert inspect.signature(model.forward) == inspect.signature(ort_model.forward)
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
# Make sure model runs without any exception
|
||||
output = ort_model(x)
|
||||
|
|
@ -3460,7 +3448,6 @@ def test_train_eval_with_various_outputs():
|
|||
|
||||
|
||||
def test_forward_dynamic_args():
|
||||
|
||||
os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"
|
||||
|
||||
device = "cuda"
|
||||
|
|
@ -3474,7 +3461,6 @@ def test_forward_dynamic_args():
|
|||
|
||||
# Make sure model runs without any exception
|
||||
for i in range(2):
|
||||
|
||||
# Test both train and inference mode
|
||||
if i % 2 == 0:
|
||||
model.train()
|
||||
|
|
@ -3506,7 +3492,6 @@ def test_forward_dynamic_args():
|
|||
|
||||
|
||||
def test_forward_dynamic_kwargs():
|
||||
|
||||
os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"
|
||||
|
||||
one = torch.FloatTensor([1])
|
||||
|
|
@ -3515,7 +3500,6 @@ def test_forward_dynamic_kwargs():
|
|||
|
||||
# Make sure model runs without any exception
|
||||
for i in range(2):
|
||||
|
||||
# Test both train and inference mode
|
||||
if i % 2 == 0:
|
||||
model.train()
|
||||
|
|
@ -3562,46 +3546,48 @@ def test_forward_dynamic_kwargs():
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"forward_statement",
|
||||
"forward_function",
|
||||
[ # Only pos_X, pos_X as positionals
|
||||
"model(pos_0, pos_1)",
|
||||
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1),
|
||||
# Only pos_X, pos_X as keywords
|
||||
"model(pos_0=pos_0, pos_1=pos_1)",
|
||||
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0=pos_0, pos_1=pos_1),
|
||||
# pos_X + *args, pos_X as positionals
|
||||
"model(pos_0, pos_1, *args)",
|
||||
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, *args),
|
||||
# pos_X + kw_X, pos_X as positionals
|
||||
"model(pos_0, pos_1, kw_0=kw_0, kw_1=kw_1)",
|
||||
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, kw_0=kw_0, kw_1=kw_1),
|
||||
# pos_X + kw_X, pos_X as keywords
|
||||
"model(pos_0=pos_0, pos_1=pos_1, kw_0=kw_0, kw_1=kw_1)",
|
||||
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0=pos_0, pos_1=pos_1, kw_0=kw_0, kw_1=kw_1),
|
||||
# pos_X + kw_X, pos_X as positionals (missing kw_1)
|
||||
"model(pos_0, pos_1, kw_0=kw_0)",
|
||||
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, kw_0=kw_0),
|
||||
# pos_X + kw_X, pos_X as keywords (missing kw_1)
|
||||
"model(pos_0=pos_0, pos_1=pos_1, kw_0=kw_0)",
|
||||
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0=pos_0, pos_1=pos_1, kw_0=kw_0),
|
||||
# pos_X + kw_X, pos_X as positionals (missing kw_0)
|
||||
"model(pos_0, pos_1, kw_1=kw_1)",
|
||||
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, kw_1=kw_1),
|
||||
# pos_X + kw_X, pos_X as keywords (missing kw_0)
|
||||
"model(pos_0=pos_0, pos_1=pos_1, kw_1=kw_1)",
|
||||
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0=pos_0, pos_1=pos_1, kw_1=kw_1),
|
||||
# pos_X + kwargs, pos_X as positionals
|
||||
"model(pos_0, pos_1, **kwargs)",
|
||||
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, **kwargs),
|
||||
# pos_X + kwargs, pos_X as keywords
|
||||
"model(pos_0=pos_0, pos_1=pos_1, **kwargs)",
|
||||
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0=pos_0, pos_1=pos_1, **kwargs),
|
||||
# pos_X + *args + kw_X, pos_X as positionals
|
||||
"model(pos_0, pos_1, *args, kw_0=kw_0, kw_1=kw_1)",
|
||||
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, *args, kw_0=kw_0, kw_1=kw_1),
|
||||
# pos_X + *args + kw_X, pos_X as positionals (missing kw_0)
|
||||
"model(pos_0, pos_1, *args, kw_1=kw_1)",
|
||||
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, *args, kw_1=kw_1),
|
||||
# pos_X + *args + kw_X, pos_X as positionals (missing kw_1)
|
||||
"model(pos_0, pos_1, *args, kw_0=kw_0)",
|
||||
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, *args, kw_0=kw_0),
|
||||
# pos_X + *args + kwargs, pos_X as positionals
|
||||
"model(pos_0, pos_1, *args, **kwargs)",
|
||||
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, *args, **kwargs),
|
||||
# pos_X + *args + kw_X + kwargs, pos_X as positionals
|
||||
"model(pos_0, pos_1, *args, kw_0=kw_0, kw_1=kw_1, **kwargs)",
|
||||
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(
|
||||
pos_0, pos_1, *args, kw_0=kw_0, kw_1=kw_1, **kwargs
|
||||
),
|
||||
# pos_X + *args + kw_X + kwargs, pos_X as positionals (missing kw_0)
|
||||
"model(pos_0, pos_1, *args, kw_1=kw_1, **kwargs)",
|
||||
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, *args, kw_1=kw_1, **kwargs),
|
||||
# pos_X + *args + kw_X + kwargs, pos_X as positionals (missing kw_1)
|
||||
"model(pos_0, pos_1, *args, kw_0=kw_0, **kwargs)",
|
||||
lambda model, pos_0, pos_1, kw_0, kw_1, args, kwargs: model(pos_0, pos_1, *args, kw_0=kw_0, **kwargs),
|
||||
],
|
||||
)
|
||||
def test_forward_call_kwargs_input(forward_statement):
|
||||
def test_forward_call_kwargs_input(forward_function):
|
||||
class KwargsNet(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
super(KwargsNet, self).__init__()
|
||||
|
|
@ -3644,7 +3630,7 @@ def test_forward_call_kwargs_input(forward_statement):
|
|||
kwargs = {"kwargs_0": torch.randn(N, D_in, device=device), "kwargs_1": torch.randn(D_in, D_in, device=device)}
|
||||
|
||||
# Training step
|
||||
prediction = eval(forward_statement)
|
||||
prediction = forward_function(model, pos_0, pos_1, kw_0, kw_1, args, kwargs)
|
||||
assert prediction is not None
|
||||
prediction = prediction.sum()
|
||||
prediction.backward()
|
||||
|
|
@ -3669,7 +3655,6 @@ def test_repro_iscontiguous():
|
|||
|
||||
|
||||
def test_forward_call_default_input():
|
||||
|
||||
os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"
|
||||
|
||||
class UnusedNet(torch.nn.Module):
|
||||
|
|
@ -3795,7 +3780,6 @@ def test_forward_call_kwargs_input_unexpected_order():
|
|||
|
||||
|
||||
def test_forward_call_lots_None():
|
||||
|
||||
os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"
|
||||
|
||||
class NoneNet(torch.nn.Module):
|
||||
|
|
@ -3943,7 +3927,6 @@ def test_primitive_inputs(bool_argument, int_argument, float_argument):
|
|||
|
||||
@pytest.mark.parametrize("bool_arguments", [(True, False), (False, True)])
|
||||
def test_changing_bool_input_re_exports_model(bool_arguments):
|
||||
|
||||
os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"
|
||||
|
||||
class PrimitiveTypesInputNet(torch.nn.Module):
|
||||
|
|
@ -4116,7 +4099,6 @@ def test_output_order():
|
|||
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu", None])
|
||||
def test_stateless_model_specified_device(device):
|
||||
|
||||
N, D_in, H, D_out = 32, 784, 500, 10
|
||||
pt_model = StatelessModel().to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
|
|
@ -4131,7 +4113,6 @@ def test_stateless_model_specified_device(device):
|
|||
|
||||
|
||||
def test_stateless_model_unspecified_device():
|
||||
|
||||
N, D_in, H, D_out = 32, 784, 500, 10
|
||||
pt_model = StatelessModel()
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
|
|
@ -4238,7 +4219,6 @@ def test_hf_save_pretrained():
|
|||
|
||||
|
||||
def test_ortmodule_string_inputs_are_ignored():
|
||||
|
||||
pt_model = MyStrNet()
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
x = torch.randn(1, 2)
|
||||
|
|
@ -4346,7 +4326,6 @@ def test_ortmodule_nested_list_input():
|
|||
|
||||
@pytest.mark.parametrize("mode", ["training", "inference"])
|
||||
def test_debug_options_save_onnx_models_os_environment(mode):
|
||||
|
||||
device = "cuda"
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
# Create a temporary directory for the onnx_models
|
||||
|
|
@ -4370,7 +4349,6 @@ def test_debug_options_save_onnx_models_os_environment(mode):
|
|||
|
||||
@pytest.mark.parametrize("mode", ["training", "inference"])
|
||||
def test_debug_options_save_onnx_models_cwd(mode):
|
||||
|
||||
device = "cuda"
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
|
|
@ -4395,7 +4373,6 @@ def test_debug_options_save_onnx_models_cwd(mode):
|
|||
|
||||
|
||||
def test_debug_options_save_onnx_models_validate_fail_on_non_writable_dir():
|
||||
|
||||
os.environ["ORTMODULE_SAVE_ONNX_PATH"] = "/non/existent/directory"
|
||||
with pytest.raises(Exception) as ex_info:
|
||||
_ = DebugOptions(save_onnx=True, onnx_prefix="my_model")
|
||||
|
|
@ -4793,7 +4770,6 @@ def test_ortmodule_setattr_ortmodule_attribute():
|
|||
|
||||
|
||||
def test_ortmodule_setattr_signals_model_changed():
|
||||
|
||||
os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"
|
||||
|
||||
class UserNet(torch.nn.Module):
|
||||
|
|
@ -4928,7 +4904,6 @@ def test_ortmodule_skip_check_load_from_os_env(policy_str, policy):
|
|||
|
||||
@pytest.mark.parametrize("is_training,deterministic", list(itertools.product([True, False], repeat=2)))
|
||||
def test_ortmodule_determinism_flag(is_training, deterministic):
|
||||
|
||||
torch.use_deterministic_algorithms(deterministic)
|
||||
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
|
|
@ -4940,9 +4915,7 @@ def test_ortmodule_determinism_flag(is_training, deterministic):
|
|||
x = torch.randn(N, D_in)
|
||||
_ = model(x)
|
||||
|
||||
from onnxruntime.training.ortmodule import _are_deterministic_algorithms_enabled
|
||||
|
||||
assert _are_deterministic_algorithms_enabled() is torch.are_deterministic_algorithms_enabled()
|
||||
assert ortmodule_module._are_deterministic_algorithms_enabled() is torch.are_deterministic_algorithms_enabled()
|
||||
|
||||
|
||||
def test_ortmodule_gradient_builder():
|
||||
|
|
@ -5053,7 +5026,6 @@ def test_override_pytorch_exporter_kwargs_using_ortmodule_extension():
|
|||
|
||||
|
||||
def test_ortmodule_fused_adam_optimizer_correctness():
|
||||
|
||||
torch.manual_seed(8888)
|
||||
|
||||
device = "cuda"
|
||||
|
|
@ -5102,7 +5074,6 @@ def test_ortmodule_fused_adam_optimizer_correctness():
|
|||
|
||||
|
||||
def test_ortmodule_fused_adam_optimizer_correctness_torch():
|
||||
|
||||
torch.manual_seed(8888)
|
||||
|
||||
device = "cuda"
|
||||
|
|
@ -5225,13 +5196,11 @@ def test_tanh_grad():
|
|||
|
||||
|
||||
def test__defined_from_envvar():
|
||||
from onnxruntime.training import ortmodule
|
||||
|
||||
os.environ["DUMMY_ORTMODULE"] = "15"
|
||||
assert ortmodule._defined_from_envvar("DUMMY_ORTMODULE", 14) == 15
|
||||
assert ortmodule_module._defined_from_envvar("DUMMY_ORTMODULE", 14) == 15
|
||||
os.environ["DUMMY_ORTMODULE"] = "15j"
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
assert ortmodule._defined_from_envvar("DUMMY_ORTMODULE", 14) == 14
|
||||
assert ortmodule_module._defined_from_envvar("DUMMY_ORTMODULE", 14) == 14
|
||||
assert len(w) == 1
|
||||
assert issubclass(w[-1].category, UserWarning)
|
||||
assert "Unable to overwrite constant" in str(w[-1].message)
|
||||
|
|
@ -5262,12 +5231,10 @@ def test_sigmoid_grad_opset13():
|
|||
N, D_in, H, D_out = 120, 15360, 500, 15360
|
||||
pt_model = NeuralNetSigmoid(D_in, H, D_out).to(device)
|
||||
|
||||
from onnxruntime.training import ortmodule
|
||||
|
||||
old_opst_cst = ortmodule.ONNX_OPSET_VERSION
|
||||
old_opst_cst = ortmodule_module.ONNX_OPSET_VERSION
|
||||
old_opset = os.getenv("ORTMODULE_ONNX_OPSET_VERSION", None)
|
||||
os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = "13"
|
||||
assert ortmodule.ONNX_OPSET_VERSION == 15
|
||||
assert ortmodule_module.ONNX_OPSET_VERSION == 15
|
||||
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
|
||||
|
|
@ -5293,8 +5260,8 @@ def test_sigmoid_grad_opset13():
|
|||
del os.environ["ORTMODULE_ONNX_OPSET_VERSION"]
|
||||
else:
|
||||
os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = old_opset
|
||||
assert ortmodule.ONNX_OPSET_VERSION == 13
|
||||
ortmodule.ONNX_OPSET_VERSION = old_opst_cst
|
||||
assert ortmodule_module.ONNX_OPSET_VERSION == 13
|
||||
ortmodule_module.ONNX_OPSET_VERSION = old_opst_cst
|
||||
|
||||
|
||||
@pytest.mark.parametrize("opset_version", [12, 13, 14, 15])
|
||||
|
|
@ -5307,10 +5274,7 @@ def test_opset_version_change(opset_version):
|
|||
|
||||
ort_model = ORTModule(model)
|
||||
|
||||
# Must import a namespace containing ONNX_OPSET_VERSION, not ONNX_OPSET_VERSION directly
|
||||
from onnxruntime.training import ortmodule
|
||||
|
||||
ortmodule.ONNX_OPSET_VERSION = opset_version
|
||||
ortmodule_module.ONNX_OPSET_VERSION = opset_version
|
||||
|
||||
# Make sure model runs without any exception
|
||||
prediction = ort_model(x)
|
||||
|
|
@ -5324,7 +5288,6 @@ def test_opset_version_change(opset_version):
|
|||
|
||||
|
||||
def test_serialize_ortmodule():
|
||||
|
||||
device = "cuda"
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
pt_model = SerializationNet(D_in, H, D_out).to(device)
|
||||
|
|
@ -5455,8 +5418,6 @@ def test_check_opset_is_default_opset_after_training():
|
|||
|
||||
|
||||
def test_random_states_unchanged_for_ortmodule():
|
||||
import numpy
|
||||
|
||||
os.environ["ORTMODULE_FALLBACK_RETRY"] = "False"
|
||||
|
||||
class NeuralNetSlice(torch.nn.Module):
|
||||
|
|
@ -5473,8 +5434,8 @@ def test_random_states_unchanged_for_ortmodule():
|
|||
if isinstance(a, tuple):
|
||||
assert len(a) == len(b)
|
||||
return all([random_state_equal(a_i, b_i) for a_i, b_i in zip(a, b)])
|
||||
if isinstance(a, numpy.ndarray):
|
||||
return numpy.array_equal(a, b)
|
||||
if isinstance(a, np.ndarray):
|
||||
return np.array_equal(a, b)
|
||||
if isinstance(a, torch.Tensor):
|
||||
return torch.equal(a, b)
|
||||
return a == b
|
||||
|
|
|
|||
Loading…
Reference in a new issue