From 084165c748bcd6fbb3fca4d60b1ece911c04c043 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Thu, 5 May 2022 11:49:40 +0800 Subject: [PATCH] Change MinGrad/MaxGrad to Use Distributed Logic (#11388) * change min max grad * resolve comments --- .../core/graph/gradient_builder.cc | 61 ++++++++-------- .../test/gradient/gradient_ops_test.cc | 21 +++++- .../python/orttraining_test_ortmodule_api.py | 73 +++++++++++-------- 3 files changed, 92 insertions(+), 63 deletions(-) diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 7d3af69cf6..5fca1a81ba 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1044,7 +1044,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) { result.push_back(axes_values_node); result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), axes_values_node.output_args[0]}, {grad})); } - } + } } result.push_back(NodeDef("Size", {I(0)}, {IA("Sized_X")})); @@ -1597,55 +1597,52 @@ IMPLEMENT_GRADIENT_BUILDER(GetTileGradient) { } IMPLEMENT_GRADIENT_BUILDER(GetMinMaxGradient) { - const auto num_src_node_inputs = GetSrcNodeInputSize(); - if (num_src_node_inputs == 1) { - if (IsGradientRequiredForSrcNodeInput(0)) { - return std::vector{NodeDef("Identity", {GO(0)}, {GI(0)})}; + const size_t num_src_node_inputs = static_cast(GetSrcNodeInputSize()); + bool has_gradient_required = false; + for (size_t i = 0; i < num_src_node_inputs; ++i) { + if (IsGradientRequiredForSrcNodeInput(i)) { + has_gradient_required = true; + break; } + } + if (!has_gradient_required) { return std::vector{}; } - if (num_src_node_inputs > 2) { - ORT_THROW("Min/Max gradient currently does not support over 2 inputs."); - } - - if (!IsGradientRequiredForSrcNodeInput(0) && !IsGradientRequiredForSrcNodeInput(1)) { - return std::vector{}; + if (num_src_node_inputs == 1) { + return std::vector{NodeDef("Identity", {GO(0)}, {GI(0)})}; } std::vector result; - std::vector y_shape; const ArgDef y = O(0); - bool get_y_shape_ok = GetShape(y, y_shape).IsOK(); - result.push_back(NodeDef("Equal", {I(1), y}, {IA("Mask_1")})); - if (IsGradientRequiredForSrcNodeInput(0)) { - result.push_back(NodeDef("Not", {IA("Mask_1")}, {IA("Mask_0")})); + std::vector sum_inputs; + for (size_t i = 0; i < num_src_node_inputs; ++i) { + const ArgDef mask = IA("Mask_" + std::to_string(i)); + const ArgDef mask_cast = IA("Mask_Cast_" + std::to_string(i)); + result.emplace_back(NodeDef("Equal", {I(i), y}, {mask})); + result.emplace_back(NodeDef("Cast", {mask}, {mask_cast}, {MakeAttribute("to", int64_t(IElemType(0)))})); + sum_inputs.emplace_back(mask_cast); } - const ArgDef a = I(0), b = I(1); - for (int i = 0; i < num_src_node_inputs; i++) { + + const ArgDef dy_scaled = IA("dY_Scaled"); + result.emplace_back(NodeDef("Sum", sum_inputs, {IA("Scale")})); + result.emplace_back(NodeDef("Div", {GO(0), IA("Scale")}, {dy_scaled})); + std::vector y_shape; + bool has_y_shape = GetShape(y, y_shape).IsOK(); + for (size_t i = 0; i < num_src_node_inputs; ++i) { if (IsGradientRequiredForSrcNodeInput(i)) { const ArgDef x = I(i); - const ArgDef mask_cast_i_def = IA("Mask_Cast_" + std::to_string(i)); const ArgDef pre_reduce_grad_i_def = IA("PreReduceGrad_" + std::to_string(i), OType(0)); - result.push_back(NodeDef("Cast", - {IA("Mask_" + std::to_string(i))}, - {mask_cast_i_def}, - {MakeAttribute("to", int64_t(IElemType(0)))})); - result.push_back(NodeDef("Mul", {mask_cast_i_def, GO(0)}, {pre_reduce_grad_i_def})); - if (a.name.compare(b.name) == 0) { - result.push_back(NodeDef("Identity", {pre_reduce_grad_i_def}, {GI(i)})); - continue; - } - + result.emplace_back(NodeDef("Mul", {dy_scaled, IA("Mask_Cast_" + std::to_string(i))}, {pre_reduce_grad_i_def})); std::vector x_shape; - if (get_y_shape_ok && GetShape(x, x_shape).IsOK()) { + if (has_y_shape && GetShape(x, x_shape).IsOK()) { std::vector x_axes; ComputeBroadcastBackwardAxes(x_shape, y_shape, &x_axes, nullptr, NodeName()); - if (x_axes.size() > 0) { + if (!x_axes.empty()) { HandleBroadcasting(pre_reduce_grad_i_def, x, GI(i), x_axes, result); } else { - result.push_back(NodeDef("Identity", {pre_reduce_grad_i_def}, {GI(i)})); + result.emplace_back(NodeDef("Identity", {pre_reduce_grad_i_def}, {GI(i)})); } } else { ArgDef x_axes_def = IA("ReduceAxes_" + x.name); diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 723017a5f2..1af5b18072 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -2519,10 +2519,12 @@ void GradientCheckerMinMaxGradHelper(const std::string op) { float max_error; GradientChecker gradient_checker; OpDef op_def{op, kOnnxDomain, 11}; - // Ensure the gap between x1 and x2 is greater than 1e-3f, otherwise the result of NumericJacobian + // Ensure the gap between tensors is greater than 1e-3f, otherwise the result of NumericJacobian // will be incorrect. This also excludes equal inputs case, where Min/Max is not smooth. std::function x1_transformer = [](float x) { return (int)(x * 100) / 100.f; }; std::function x2_transformer = [](float x) { return (int)(x * 100) / 100.f + 0.002f; }; + std::function x3_transformer = [](float x) { return (int)(x * 100) / 100.f + 0.004f; }; + std::function x4_transformer = [](float x) { return (int)(x * 100) / 100.f + 0.006f; }; TensorInfo x1_info({2, 3}, true, &x1_transformer); TensorInfo y_info({2, 3}, true); @@ -2543,6 +2545,23 @@ void GradientCheckerMinMaxGradHelper(const std::string op) { ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x1_info, x2_info}, {y_info}, &max_error)); EXPECT_IS_TINY(max_error); } + + // More than 2 inputs. + { + TensorInfo x2_info({2, 3}, true, &x2_transformer); + TensorInfo x3_info({2, 3}, true, &x3_transformer); + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x1_info, x2_info, x3_info}, {y_info}, &max_error)); + EXPECT_IS_TINY(max_error); + } + + { + TensorInfo x2_info({3}, true, &x2_transformer); + TensorInfo x3_info({2, 1}, true, &x3_transformer); + TensorInfo x4_info({2, 3}, true, &x4_transformer); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x1_info, x2_info, x3_info, x4_info}, {y_info}, &max_error)); + EXPECT_IS_TINY(max_error); + } } TEST(GradientCheckerTest, MinGrad) { diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index c3e4b8652b..a85ad8e49e 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -2,44 +2,41 @@ # Licensed under the MIT License. # orttraining_test_ortmodule_api.py +import copy import itertools import math -import random -import copy -import torch -from transformers import AutoConfig, BertForSequenceClassification, Trainer -from transformers.modeling_outputs import SequenceClassifierOutput -import pytest -from time import sleep -import warnings -from unittest.mock import patch -from collections import OrderedDict -from collections import namedtuple -from inspect import signature -import tempfile import os import pickle +import random +import tempfile +import warnings +from collections import OrderedDict, namedtuple from distutils.version import LooseVersion -from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient -from onnxruntime.training.ortmodule import ( - ORTModule, - _utils, - _io, - DebugOptions, - LogLevel, - _fallback, - _graph_execution_manager, -) -import onnxruntime.training.ortmodule as ortmodule_module - -from onnxruntime.training.optim import FusedAdam, AdamWMode -from transformers import AdamW +from inspect import signature +from time import sleep +from unittest.mock import patch import _test_helpers +import pytest +import torch # Import autocasting libs from torch.cuda import amp +from transformers import AdamW, AutoConfig, BertForSequenceClassification, Trainer +from transformers.modeling_outputs import SequenceClassifierOutput +import onnxruntime.training.ortmodule as ortmodule_module +from onnxruntime.training.optim import AdamWMode, FusedAdam +from onnxruntime.training.ortmodule import ( + DebugOptions, + LogLevel, + ORTModule, + _fallback, + _graph_execution_manager, + _io, + _utils, +) +from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient DEFAULT_OPSET = 14 @@ -1123,9 +1120,12 @@ def test_gradient_correctness_max(operator, dim, keepdim): _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) -@pytest.mark.skip("temporarily disabled due to PyTorch's change of max implementation") +# Before 1.10 (excluded), Torch's min/max(x,y) will assign dY to y's dX if value from x and y are equal. +# From 1.10, both x and y's dX will be dY/2. ORT follows this distribution logic, so skip below test if Torch version +# is lower than 1.10. +@pytest.mark.skipif(LooseVersion(torch.__version__) < LooseVersion("1.10.0"), reason="PyTorch 1.9 incompatible") @pytest.mark.parametrize("operator", ["min", "max"]) -def test_gradient_correctness_max_two_tensors(operator): +def test_gradient_correctness_minmax_two_tensors(operator): func = getattr(torch, operator) class NeuralNetMaxTwoTensors(torch.nn.Module): @@ -1153,6 +1153,19 @@ def test_gradient_correctness_max_two_tensors(operator): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) + _test_helpers.assert_values_are_close(ort_other.grad, pt_other.grad) + + # Simple test for case that has equal value. + pt_input = torch.tensor([0.0, 0.0, 1.0, 1.0], device=device, requires_grad=True) + pt_other = torch.tensor([1.0, 0.0, 1.0, 0.0], device=device, requires_grad=True) + ort_input = copy.deepcopy(pt_input) + ort_other = copy.deepcopy(pt_other) + pt_prediction = run_step(pt_model, pt_input, pt_other) + ort_prediction = run_step(ort_model, ort_input, ort_other) + + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) + _test_helpers.assert_values_are_close(ort_other.grad, pt_other.grad) def test_gradient_correctness_argmax_unfold(): @@ -2711,7 +2724,7 @@ def test_forward_data_and_model_on_different_devices(data_device, model_device): # Now that the model has been exported, feed in data from device other than the model device x = torch.randn(N, D_in, device=data_device) - from onnxruntime.training.ortmodule._fallback import _FallbackPolicy, ORTModuleDeviceException + from onnxruntime.training.ortmodule._fallback import ORTModuleDeviceException, _FallbackPolicy if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE): # Fallback