Change MinGrad/MaxGrad to Use Distributed Logic (#11388)

* change min max grad

* resolve comments
This commit is contained in:
Vincent Wang 2022-05-05 11:49:40 +08:00 committed by GitHub
parent 860ba8820b
commit 084165c748
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 92 additions and 63 deletions

View file

@ -1044,7 +1044,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
result.push_back(axes_values_node);
result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), axes_values_node.output_args[0]}, {grad}));
}
}
}
}
result.push_back(NodeDef("Size", {I(0)}, {IA("Sized_X")}));
@ -1597,55 +1597,52 @@ IMPLEMENT_GRADIENT_BUILDER(GetTileGradient) {
}
IMPLEMENT_GRADIENT_BUILDER(GetMinMaxGradient) {
const auto num_src_node_inputs = GetSrcNodeInputSize();
if (num_src_node_inputs == 1) {
if (IsGradientRequiredForSrcNodeInput(0)) {
return std::vector<NodeDef>{NodeDef("Identity", {GO(0)}, {GI(0)})};
const size_t num_src_node_inputs = static_cast<size_t>(GetSrcNodeInputSize());
bool has_gradient_required = false;
for (size_t i = 0; i < num_src_node_inputs; ++i) {
if (IsGradientRequiredForSrcNodeInput(i)) {
has_gradient_required = true;
break;
}
}
if (!has_gradient_required) {
return std::vector<NodeDef>{};
}
if (num_src_node_inputs > 2) {
ORT_THROW("Min/Max gradient currently does not support over 2 inputs.");
}
if (!IsGradientRequiredForSrcNodeInput(0) && !IsGradientRequiredForSrcNodeInput(1)) {
return std::vector<NodeDef>{};
if (num_src_node_inputs == 1) {
return std::vector<NodeDef>{NodeDef("Identity", {GO(0)}, {GI(0)})};
}
std::vector<NodeDef> result;
std::vector<Dimension> y_shape;
const ArgDef y = O(0);
bool get_y_shape_ok = GetShape(y, y_shape).IsOK();
result.push_back(NodeDef("Equal", {I(1), y}, {IA("Mask_1")}));
if (IsGradientRequiredForSrcNodeInput(0)) {
result.push_back(NodeDef("Not", {IA("Mask_1")}, {IA("Mask_0")}));
std::vector<ArgDef> sum_inputs;
for (size_t i = 0; i < num_src_node_inputs; ++i) {
const ArgDef mask = IA("Mask_" + std::to_string(i));
const ArgDef mask_cast = IA("Mask_Cast_" + std::to_string(i));
result.emplace_back(NodeDef("Equal", {I(i), y}, {mask}));
result.emplace_back(NodeDef("Cast", {mask}, {mask_cast}, {MakeAttribute("to", int64_t(IElemType(0)))}));
sum_inputs.emplace_back(mask_cast);
}
const ArgDef a = I(0), b = I(1);
for (int i = 0; i < num_src_node_inputs; i++) {
const ArgDef dy_scaled = IA("dY_Scaled");
result.emplace_back(NodeDef("Sum", sum_inputs, {IA("Scale")}));
result.emplace_back(NodeDef("Div", {GO(0), IA("Scale")}, {dy_scaled}));
std::vector<Dimension> y_shape;
bool has_y_shape = GetShape(y, y_shape).IsOK();
for (size_t i = 0; i < num_src_node_inputs; ++i) {
if (IsGradientRequiredForSrcNodeInput(i)) {
const ArgDef x = I(i);
const ArgDef mask_cast_i_def = IA("Mask_Cast_" + std::to_string(i));
const ArgDef pre_reduce_grad_i_def = IA("PreReduceGrad_" + std::to_string(i), OType(0));
result.push_back(NodeDef("Cast",
{IA("Mask_" + std::to_string(i))},
{mask_cast_i_def},
{MakeAttribute("to", int64_t(IElemType(0)))}));
result.push_back(NodeDef("Mul", {mask_cast_i_def, GO(0)}, {pre_reduce_grad_i_def}));
if (a.name.compare(b.name) == 0) {
result.push_back(NodeDef("Identity", {pre_reduce_grad_i_def}, {GI(i)}));
continue;
}
result.emplace_back(NodeDef("Mul", {dy_scaled, IA("Mask_Cast_" + std::to_string(i))}, {pre_reduce_grad_i_def}));
std::vector<Dimension> x_shape;
if (get_y_shape_ok && GetShape(x, x_shape).IsOK()) {
if (has_y_shape && GetShape(x, x_shape).IsOK()) {
std::vector<int64_t> x_axes;
ComputeBroadcastBackwardAxes(x_shape, y_shape, &x_axes, nullptr, NodeName());
if (x_axes.size() > 0) {
if (!x_axes.empty()) {
HandleBroadcasting(pre_reduce_grad_i_def, x, GI(i), x_axes, result);
} else {
result.push_back(NodeDef("Identity", {pre_reduce_grad_i_def}, {GI(i)}));
result.emplace_back(NodeDef("Identity", {pre_reduce_grad_i_def}, {GI(i)}));
}
} else {
ArgDef x_axes_def = IA("ReduceAxes_" + x.name);

View file

@ -2519,10 +2519,12 @@ void GradientCheckerMinMaxGradHelper(const std::string op) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{op, kOnnxDomain, 11};
// Ensure the gap between x1 and x2 is greater than 1e-3f, otherwise the result of NumericJacobian
// Ensure the gap between tensors is greater than 1e-3f, otherwise the result of NumericJacobian
// will be incorrect. This also excludes the equal-inputs case, where Min/Max is not smooth.
std::function<float(float)> x1_transformer = [](float x) { return (int)(x * 100) / 100.f; };
std::function<float(float)> x2_transformer = [](float x) { return (int)(x * 100) / 100.f + 0.002f; };
std::function<float(float)> x3_transformer = [](float x) { return (int)(x * 100) / 100.f + 0.004f; };
std::function<float(float)> x4_transformer = [](float x) { return (int)(x * 100) / 100.f + 0.006f; };
TensorInfo x1_info({2, 3}, true, &x1_transformer);
TensorInfo y_info({2, 3}, true);
@ -2543,6 +2545,23 @@ void GradientCheckerMinMaxGradHelper(const std::string op) {
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x1_info, x2_info}, {y_info}, &max_error));
EXPECT_IS_TINY(max_error);
}
// More than 2 inputs.
{
TensorInfo x2_info({2, 3}, true, &x2_transformer);
TensorInfo x3_info({2, 3}, true, &x3_transformer);
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x1_info, x2_info, x3_info}, {y_info}, &max_error));
EXPECT_IS_TINY(max_error);
}
{
TensorInfo x2_info({3}, true, &x2_transformer);
TensorInfo x3_info({2, 1}, true, &x3_transformer);
TensorInfo x4_info({2, 3}, true, &x4_transformer);
ASSERT_STATUS_OK(
gradient_checker.ComputeGradientError(op_def, {x1_info, x2_info, x3_info, x4_info}, {y_info}, &max_error));
EXPECT_IS_TINY(max_error);
}
}
TEST(GradientCheckerTest, MinGrad) {

View file

@ -2,44 +2,41 @@
# Licensed under the MIT License.
# orttraining_test_ortmodule_api.py
import copy
import itertools
import math
import random
import copy
import torch
from transformers import AutoConfig, BertForSequenceClassification, Trainer
from transformers.modeling_outputs import SequenceClassifierOutput
import pytest
from time import sleep
import warnings
from unittest.mock import patch
from collections import OrderedDict
from collections import namedtuple
from inspect import signature
import tempfile
import os
import pickle
import random
import tempfile
import warnings
from collections import OrderedDict, namedtuple
from distutils.version import LooseVersion
from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient
from onnxruntime.training.ortmodule import (
ORTModule,
_utils,
_io,
DebugOptions,
LogLevel,
_fallback,
_graph_execution_manager,
)
import onnxruntime.training.ortmodule as ortmodule_module
from onnxruntime.training.optim import FusedAdam, AdamWMode
from transformers import AdamW
from inspect import signature
from time import sleep
from unittest.mock import patch
import _test_helpers
import pytest
import torch
# Import autocasting libs
from torch.cuda import amp
from transformers import AdamW, AutoConfig, BertForSequenceClassification, Trainer
from transformers.modeling_outputs import SequenceClassifierOutput
import onnxruntime.training.ortmodule as ortmodule_module
from onnxruntime.training.optim import AdamWMode, FusedAdam
from onnxruntime.training.ortmodule import (
DebugOptions,
LogLevel,
ORTModule,
_fallback,
_graph_execution_manager,
_io,
_utils,
)
from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient
DEFAULT_OPSET = 14
@ -1123,9 +1120,12 @@ def test_gradient_correctness_max(operator, dim, keepdim):
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
@pytest.mark.skip("temporarily disabled due to PyTorch's change of max implementation")
# Before PyTorch 1.10 (exclusive), Torch's min/max(x, y) assigned dY to y's dX when the values of x and y were equal.
# From 1.10 onward, both x's and y's dX are dY/2. ORT follows this distribution logic, so skip the test below if the
# Torch version is lower than 1.10.
@pytest.mark.skipif(LooseVersion(torch.__version__) < LooseVersion("1.10.0"), reason="PyTorch 1.9 incompatible")
@pytest.mark.parametrize("operator", ["min", "max"])
def test_gradient_correctness_max_two_tensors(operator):
def test_gradient_correctness_minmax_two_tensors(operator):
func = getattr(torch, operator)
class NeuralNetMaxTwoTensors(torch.nn.Module):
@ -1153,6 +1153,19 @@ def test_gradient_correctness_max_two_tensors(operator):
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
_test_helpers.assert_values_are_close(ort_other.grad, pt_other.grad)
# Simple test for the case where the two inputs have equal values.
pt_input = torch.tensor([0.0, 0.0, 1.0, 1.0], device=device, requires_grad=True)
pt_other = torch.tensor([1.0, 0.0, 1.0, 0.0], device=device, requires_grad=True)
ort_input = copy.deepcopy(pt_input)
ort_other = copy.deepcopy(pt_other)
pt_prediction = run_step(pt_model, pt_input, pt_other)
ort_prediction = run_step(ort_model, ort_input, ort_other)
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
_test_helpers.assert_values_are_close(ort_other.grad, pt_other.grad)
def test_gradient_correctness_argmax_unfold():
@ -2711,7 +2724,7 @@ def test_forward_data_and_model_on_different_devices(data_device, model_device):
# Now that the model has been exported, feed in data from device other than the model device
x = torch.randn(N, D_in, device=data_device)
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy, ORTModuleDeviceException
from onnxruntime.training.ortmodule._fallback import ORTModuleDeviceException, _FallbackPolicy
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
# Fallback