onnxruntime/orttraining/orttraining/test/python/_test_helpers.py
pengwa fbf5d09a0c
Fix random failure of ortmodule_api.py::test_unused_parameters (#14729)
### Fix random failure of ortmodule_api.py::test_unused_parameters

Fix FAILED
orttraining_test_ortmodule_api.py::test_unused_parameters[model1-none_pt_params1]
for orttraining-linux-gpu-ci-pipeline CI pipeline

```
=================================== FAILURES ===================================
________________ test_unused_parameters[model1-none_pt_params1] ________________

model = UnusedMiddleParameterNet(
  (fc1): Linear(in_features=784, out_features=500, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=500, out_features=400, bias=True)
  (fc3): Linear(in_features=500, out_features=10, bias=True)
)
none_pt_params = ['fc2.weight', 'fc2.bias']

    @pytest.mark.parametrize(
        "model, none_pt_params",
        [
            (UnusedBeginParameterNet(784, 500, 400, 10), ["fc1.weight", "fc1.bias"]),
            (UnusedMiddleParameterNet(784, 500, 400, 10), ["fc2.weight", "fc2.bias"]),
            (UnusedEndParameterNet(784, 500, 400, 10), ["fc2.weight", "fc2.bias"]),
        ],
    )
    def test_unused_parameters(model, none_pt_params):
        device = "cuda"
    
        N, D_in, H1, H2, D_out = 64, 784, 500, 400, 10
        model = model.to(device)
        ort_model = ORTModule(copy.deepcopy(model))
    
        # Make sure model runs without any exception
        for _ in range(5):
            x = torch.randn(N, D_in, device=device)
            y = copy.deepcopy(x)
    
            out_pt = model(x)
            out_ort = ort_model(y)
            loss_pt = out_pt.sum()
            loss_pt.backward()
            loss_ort = out_ort.sum()
            loss_ort.backward()
            _test_helpers.assert_values_are_close(out_ort, out_pt)
>           _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, model, none_pt_params=none_pt_params)

orttraining_test_ortmodule_api.py:4050: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
_test_helpers.py:216: in assert_gradients_match_and_reset_gradient
    assert_values_are_close(ort_param.grad, pt_param.grad, rtol=rtol, atol=atol)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

```

Initially the test runs very well. As we insert more and more tests,
when running ortmodule_api.py::test_unused_parameters, the random
generated data got changed, and now it is more easily to generate an
input data that produce a result the break existing rtol and atol.

The example data, 0.1041 only have very minor diff, e.g. abs_diff:
2.2649765014648438e-06.
> The torch.allclose judge it is not equal because: abs_diff> 0.1041 *
rtol + atol = 1.041e-1 * 1e-5 + 1e-6 =-2.041e-6.
> Additionally, according to math
[here](7b31bcda2e/orttraining/orttraining/test/python/_test_helpers.py (L230))
The maximum atol is 1.2238311910550692e-06 > current atol(1e-6), maximum
rtol is 1.2149855137977283e-05 > current rtol(1e-5).

This PR looses the atol to 1e-5, rtol to 1e-4 .
2023-02-20 18:09:53 +08:00

446 lines
16 KiB
Python

import copy
import os
import numpy as np
import torch
from numpy.testing import assert_allclose
from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer
from onnxruntime.training import orttrainer
try:
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.ortmodule._fallback import ORTModuleInitException
from onnxruntime.training.ortmodule._graph_execution_manager_factory import GraphExecutionManagerFactory
except ImportError:
# Some pipelines do not contain ORTModule
pass
except Exception as e:
from onnxruntime.training.ortmodule._fallback import ORTModuleInitException
if isinstance(e, ORTModuleInitException):
# ORTModule is present but not ready to run
# That is OK because this file is also used by ORTTrainer tests
pass
raise
def is_all_or_nothing_fallback_enabled(model, policy=None):
from onnxruntime.training.ortmodule import ORTMODULE_FALLBACK_POLICY
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
if os.getenv("ORTMODULE_FALLBACK_POLICY") == _FallbackPolicy.FALLBACK_DISABLE.name:
return False
if not policy:
policy = _FallbackPolicy.FALLBACK_DISABLE
fallback_on_env = policy in ORTMODULE_FALLBACK_POLICY
fallback_on_model = False
if model:
fallback_on_model = (
policy in model._torch_module._execution_manager(is_training=True)._fallback_manager.policy
or policy in model._torch_module._execution_manager(is_training=False)._fallback_manager.policy
)
return fallback_on_env or fallback_on_model
def assert_model_outputs(output_a, output_b, verbose=False, rtol=1e-7, atol=0):
r"""Asserts whether output_a and output_b difference is within specified tolerance
Args:
output_a, output_b (list): Two list with of numeric values
verbose (bool, default is False): if True, prints absolute difference for each weight
rtol (float, default is 1e-7): Max relative difference
atol (float, default is 1e-4): Max absolute difference
"""
assert isinstance(output_a, list) and isinstance(output_b, list), "output_a and output_b must be list of numbers"
if len(output_a) != len(output_b):
raise AssertionError(
"output_a and output_b must have the same length (%r != %r)." % (len(output_a), len(output_b))
)
# for idx in range(len(output_a)):
assert_allclose(output_a, output_b, rtol=rtol, atol=atol, err_msg=f"Model output value mismatch")
def assert_onnx_weights(model_a, model_b, verbose=False, rtol=1e-7, atol=0):
r"""Asserts whether weight difference between models a and b differences are within specified tolerance
Compares the weights of two different ONNX models (model_a and model_b)
and raises AssertError when they diverge by more than atol or rtol
Args:
model_a, model_b (ORTTrainer): Two instances of ORTTrainer with the same model structure
verbose (bool, default is False): if True, prints absolute difference for each weight
rtol (float, default is 1e-7): Max relative difference
atol (float, default is 1e-4): Max absolute difference
"""
assert isinstance(model_a, orttrainer.ORTTrainer) and isinstance(model_b, orttrainer.ORTTrainer)
state_dict_a, state_dict_b = model_a._training_session.get_state(), model_b._training_session.get_state()
assert len(state_dict_a.items()) == len(state_dict_b.items())
_assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol)
def assert_legacy_onnx_weights(model_a, model_b, verbose=False, rtol=1e-7, atol=0):
r"""Asserts whether weight difference between models a and b differences are within specified tolerance
Compares the weights of a legacy model model_a and experimental model_b model
and raises AssertError when they diverge by more than atol or rtol.
Args:
model_a (ORTTrainer): Instance of legacy ORTTrainer
model_b (ORTTrainer): Instance of experimental ORTTrainer
verbose (bool, default is False): if True, prints absolute difference for each weight.
rtol (float, default is 1e-7): Max relative difference
atol (float, default is 1e-4): Max absolute difference
"""
assert isinstance(model_a, orttrainer.ORTTrainer) and isinstance(model_b, Legacy_ORTTrainer)
state_dict_a, state_dict_b = model_a._training_session.get_state(), model_b.session.get_state()
assert len(state_dict_a.items()) == len(state_dict_b.items())
_assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol)
def _assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol):
r"""Asserts whether dicts a and b value differences are within specified tolerance
Compares the weights of two model's state_dict dicts and raises AssertError
when they diverge by more than atol or rtol
Args:
model_a (ORTTrainer): Instance of legacy ORTTrainer
model_b (ORTTrainer): Instance of experimental ORTTrainer
verbose (bool, default is False): if True, prints absolute difference for each weight.
rtol (float, default is 1e-7): Max relative difference
atol (float, default is 1e-4): Max absolute difference
"""
for (a_name, a_val), (b_name, b_val) in zip(state_dict_a.items(), state_dict_b.items()):
np_a_vals = np.array(a_val).flatten()
np_b_vals = np.array(b_val).flatten()
assert np_a_vals.shape == np_b_vals.shape
if verbose:
print(f"Weight name: {a_name}: absolute difference: {np.abs(np_a_vals-np_b_vals).max()}")
assert_allclose(a_val, b_val, rtol=rtol, atol=atol, err_msg=f"Weight mismatch for {a_name}")
def assert_optim_state(expected_state, actual_state, rtol=1e-7, atol=0):
r"""Asserts whether optimizer state differences are within specified tolerance
Compares the expected and actual optimizer states of dicts and raises AssertError
when they diverge by more than atol or rtol.
The optimizer dict is of the form:
model_weight_name:
{
"Moment_1": moment1_tensor,
"Moment_2": moment2_tensor,
"Update_Count": update_tensor # if optimizer is adam, absent otherwise
},
...
"shared_optimizer_state": # if optimizer is shared, absent otherwise.
So far, only lamb optimizer uses this.
{
"step": step_tensor # int array of size 1
}
Args:
expected_state (dict(dict())): Expected optimizer state
actual_state (dict(dict())): Actual optimizer state
rtol (float, default is 1e-7): Max relative difference
atol (float, default is 0): Max absolute difference
"""
assert expected_state.keys() == actual_state.keys()
for param_name, a_state in actual_state.items():
for k, v in a_state.items():
assert_allclose(
v,
expected_state[param_name][k],
rtol=rtol,
atol=atol,
err_msg=f"Optimizer state mismatch for param {param_name}, key {k}",
)
def is_dynamic_axes(model):
# Check inputs
for inp in model._torch_module._execution_manager(model._is_training())._onnx_models.optimized_model.graph.input:
shape = inp.type.tensor_type.shape
if shape:
for dim in shape.dim:
if dim.dim_param and not isinstance(dim.dim_param, str):
return False
# Check outputs
for out in model._torch_module._execution_manager(model._is_training())._onnx_models.optimized_model.graph.output:
shape = out.type.tensor_type.shape
if shape:
for dim in shape.dim:
if dim.dim_param and not isinstance(dim.dim_param, str):
return False
return True
# TODO: thiagofc: Checkpoint related for redesign
def _get_name(name):
if os.path.exists(name):
return name
rel = os.path.join("testdata", name)
if os.path.exists(rel):
return rel
this = os.path.dirname(__file__)
data = os.path.join(this, "..", "testdata")
res = os.path.join(data, name)
if os.path.exists(res):
return res
raise FileNotFoundError("Unable to find '{0}' or '{1}' or '{2}'".format(name, rel, res))
# Depending on calling backward() from which outputs, it's possible that grad of some weights are not calculated.
# none_pt_params is to tell what these weights are, so we will not compare the tensors.
def assert_gradients_match_and_reset_gradient(
ort_model, pt_model, none_pt_params=[], reset_gradient=True, rtol=1e-04, atol=1e-05
):
ort_named_params = list(ort_model.named_parameters())
pt_named_params = list(pt_model.named_parameters())
assert len(ort_named_params) == len(pt_named_params)
for ort_named_param, pt_named_param in zip(ort_named_params, pt_named_params):
ort_name, ort_param = ort_named_param
pt_name, pt_param = pt_named_param
assert pt_name in ort_name
if pt_name in none_pt_params:
assert pt_param.grad is None
assert ort_param.grad is None or not torch.is_nonzero(torch.count_nonzero(ort_param.grad))
else:
assert_values_are_close(ort_param.grad, pt_param.grad, rtol=rtol, atol=atol)
if reset_gradient:
ort_param.grad = None
pt_param.grad = None
def assert_values_are_close(input, other, rtol=1e-04, atol=1e-05):
are_close = torch.allclose(input, other, rtol=rtol, atol=atol)
if not are_close:
abs_diff = torch.abs(input - other)
abs_other = torch.abs(other)
max_atol = torch.max((abs_diff - rtol * abs_other))
max_rtol = torch.max((abs_diff - atol) / abs_other)
err_msg = "The maximum atol is {}, maximum rtol is {}".format(max_atol, max_rtol)
assert False, err_msg
def _run_model_on_device(device, model, input_list, label_input, is_eval_mode=False, run_forward_twice=False):
if is_eval_mode:
model.eval()
else:
model.train()
def generate_inputs(input_list_, label_input_):
with torch.no_grad():
inputs_on_device = [input_.to(device) for input_ in input_list_]
for i, val in enumerate(input_list_):
if val.requires_grad:
inputs_on_device[i].requires_grad_()
with torch.no_grad():
target = label_input_.to(device)
return inputs_on_device, target
inputs_on_device1, target1 = generate_inputs(input_list, label_input)
if run_forward_twice is True:
inputs_on_device2, target2 = generate_inputs(input_list, label_input)
output1 = model(*inputs_on_device1)
if run_forward_twice is True:
output2 = model(*inputs_on_device2)
forward_outputs = [output1]
grad_outputs = []
if not is_eval_mode:
criterion = torch.nn.MSELoss()
loss = criterion(output1, target1)
if run_forward_twice is True:
loss += criterion(output2, target2)
loss.backward()
for name, param in model.named_parameters():
if param.requires_grad:
grad_outputs.append(param.grad)
return forward_outputs, grad_outputs
def run_with_pytorch_on_device(device, model, input_list, label_input, is_eval_mode=False, run_forward_twice=False):
with torch.no_grad():
model = copy.deepcopy(model).to(device)
return _run_model_on_device(device, model, input_list, label_input, is_eval_mode, run_forward_twice)
def run_with_ort_on_device(device, model, input_list, label_input, is_eval_mode=False, run_forward_twice=False):
with torch.no_grad():
model = copy.deepcopy(model)
model.to(device)
model = ORTModule(model)
return _run_model_on_device(device, model, input_list, label_input, is_eval_mode, run_forward_twice)
def compare_tensor_list(val_list_a, val_list_b):
for val_a, val_b in zip(val_list_a, val_list_b):
assert_values_are_close(val_a, val_b, atol=1e-7, rtol=1e-6)
def run_training_test_and_compare(
pt_model_builder_func,
pt_model_inputs_generator,
pt_model_label_input,
run_forward_twice=False,
ignore_grad_compare=False,
expected_outputs=[],
expected_grads=[],
):
cpu = torch.device("cpu")
def cpu_barrier_func():
pass
run_training_test_on_device_and_compare(
cpu,
pt_model_builder_func,
pt_model_inputs_generator,
pt_model_label_input,
cpu_barrier_func,
run_forward_twice,
ignore_grad_compare,
expected_outputs,
expected_grads,
)
def cuda_barrier_func():
torch.cuda.synchronize()
cuda = torch.device("cuda:0")
run_training_test_on_device_and_compare(
cuda,
pt_model_builder_func,
pt_model_inputs_generator,
pt_model_label_input,
cuda_barrier_func,
run_forward_twice,
ignore_grad_compare,
expected_outputs,
expected_grads,
)
def run_training_test_on_device_and_compare(
device,
pt_model_builder_func,
pt_model_inputs_generator,
pt_model_label_input,
barrier_func,
run_forward_twice=False,
ignore_grad_compare=False,
expected_outputs=[],
expected_grads=[],
):
repeats = 16
for i in range(repeats):
m = pt_model_builder_func()
x = pt_model_inputs_generator()
with torch.no_grad():
m_ort = copy.deepcopy(m)
x_ort = copy.deepcopy(x)
outputs, grads = run_with_pytorch_on_device(
device, m, [x], pt_model_label_input, run_forward_twice=run_forward_twice
)
barrier_func()
outputs_ort, grads_ort = run_with_ort_on_device(
device, m_ort, [x_ort], pt_model_label_input, run_forward_twice=run_forward_twice
)
barrier_func()
val_list_a = [o.detach().cpu() for o in outputs if o is not None]
val_list_b = [o.detach().cpu() for o in outputs_ort if o is not None]
compare_tensor_list(val_list_a, val_list_b)
if len(expected_outputs) > 0:
compare_tensor_list(val_list_a, expected_outputs)
# For some test, it is expected the diff might be big due to inconsistent computation orders.
if ignore_grad_compare is False:
val_list_a = [o.detach().cpu() for o in grads if o is not None]
val_list_b = [o.detach().cpu() for o in grads_ort if o is not None]
compare_tensor_list(val_list_a, val_list_b)
if len(expected_grads) > 0:
compare_tensor_list(val_list_a, expected_grads)
def run_evaluate_test_and_compare(
pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, run_forward_twice=False
):
cpu = torch.device("cpu")
def cpu_barrier_func():
pass
run_evaluate_test_on_device_and_compare(
cpu,
pt_model_builder_func,
pt_model_inputs_generator,
pt_model_label_input,
cpu_barrier_func,
run_forward_twice=run_forward_twice,
)
def cuda_barrier_func():
torch.cuda.synchronize()
pass
cuda = torch.device("cuda:0")
run_evaluate_test_on_device_and_compare(
cuda,
pt_model_builder_func,
pt_model_inputs_generator,
pt_model_label_input,
cuda_barrier_func,
run_forward_twice=run_forward_twice,
)
def run_evaluate_test_on_device_and_compare(
device,
pt_model_builder_func,
pt_model_inputs_generator,
pt_model_label_input,
barrier_func,
run_forward_twice=False,
):
repeats = 16
for i in range(repeats):
m = pt_model_builder_func()
x = pt_model_inputs_generator()
m_ort = copy.deepcopy(m)
x_ort = copy.deepcopy(x)
outputs, grads = run_with_pytorch_on_device(
device, m, [x], pt_model_label_input, is_eval_mode=True, run_forward_twice=run_forward_twice
)
barrier_func()
outputs_ort, grads_ort = run_with_ort_on_device(
device, m_ort, [x_ort], pt_model_label_input, is_eval_mode=True, run_forward_twice=run_forward_twice
)
barrier_func()
val_list_a = [o.detach().cpu() for o in outputs if o is not None]
val_list_b = [o.detach().cpu() for o in outputs_ort if o is not None]
compare_tensor_list(val_list_a, val_list_b)