mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Change MinGrad/MaxGrad to Use Distributed Logic (#11388)
* change min max grad * resolve comments
This commit is contained in:
parent
860ba8820b
commit
084165c748
3 changed files with 92 additions and 63 deletions
|
|
@ -1044,7 +1044,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
|
|||
result.push_back(axes_values_node);
|
||||
result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), axes_values_node.output_args[0]}, {grad}));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result.push_back(NodeDef("Size", {I(0)}, {IA("Sized_X")}));
|
||||
|
|
@ -1597,55 +1597,52 @@ IMPLEMENT_GRADIENT_BUILDER(GetTileGradient) {
|
|||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetMinMaxGradient) {
|
||||
const auto num_src_node_inputs = GetSrcNodeInputSize();
|
||||
if (num_src_node_inputs == 1) {
|
||||
if (IsGradientRequiredForSrcNodeInput(0)) {
|
||||
return std::vector<NodeDef>{NodeDef("Identity", {GO(0)}, {GI(0)})};
|
||||
const size_t num_src_node_inputs = static_cast<size_t>(GetSrcNodeInputSize());
|
||||
bool has_gradient_required = false;
|
||||
for (size_t i = 0; i < num_src_node_inputs; ++i) {
|
||||
if (IsGradientRequiredForSrcNodeInput(i)) {
|
||||
has_gradient_required = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!has_gradient_required) {
|
||||
return std::vector<NodeDef>{};
|
||||
}
|
||||
|
||||
if (num_src_node_inputs > 2) {
|
||||
ORT_THROW("Min/Max gradient currently does not support over 2 inputs.");
|
||||
}
|
||||
|
||||
if (!IsGradientRequiredForSrcNodeInput(0) && !IsGradientRequiredForSrcNodeInput(1)) {
|
||||
return std::vector<NodeDef>{};
|
||||
if (num_src_node_inputs == 1) {
|
||||
return std::vector<NodeDef>{NodeDef("Identity", {GO(0)}, {GI(0)})};
|
||||
}
|
||||
|
||||
std::vector<NodeDef> result;
|
||||
std::vector<Dimension> y_shape;
|
||||
const ArgDef y = O(0);
|
||||
bool get_y_shape_ok = GetShape(y, y_shape).IsOK();
|
||||
result.push_back(NodeDef("Equal", {I(1), y}, {IA("Mask_1")}));
|
||||
if (IsGradientRequiredForSrcNodeInput(0)) {
|
||||
result.push_back(NodeDef("Not", {IA("Mask_1")}, {IA("Mask_0")}));
|
||||
std::vector<ArgDef> sum_inputs;
|
||||
for (size_t i = 0; i < num_src_node_inputs; ++i) {
|
||||
const ArgDef mask = IA("Mask_" + std::to_string(i));
|
||||
const ArgDef mask_cast = IA("Mask_Cast_" + std::to_string(i));
|
||||
result.emplace_back(NodeDef("Equal", {I(i), y}, {mask}));
|
||||
result.emplace_back(NodeDef("Cast", {mask}, {mask_cast}, {MakeAttribute("to", int64_t(IElemType(0)))}));
|
||||
sum_inputs.emplace_back(mask_cast);
|
||||
}
|
||||
const ArgDef a = I(0), b = I(1);
|
||||
for (int i = 0; i < num_src_node_inputs; i++) {
|
||||
|
||||
const ArgDef dy_scaled = IA("dY_Scaled");
|
||||
result.emplace_back(NodeDef("Sum", sum_inputs, {IA("Scale")}));
|
||||
result.emplace_back(NodeDef("Div", {GO(0), IA("Scale")}, {dy_scaled}));
|
||||
std::vector<Dimension> y_shape;
|
||||
bool has_y_shape = GetShape(y, y_shape).IsOK();
|
||||
for (size_t i = 0; i < num_src_node_inputs; ++i) {
|
||||
if (IsGradientRequiredForSrcNodeInput(i)) {
|
||||
const ArgDef x = I(i);
|
||||
const ArgDef mask_cast_i_def = IA("Mask_Cast_" + std::to_string(i));
|
||||
const ArgDef pre_reduce_grad_i_def = IA("PreReduceGrad_" + std::to_string(i), OType(0));
|
||||
result.push_back(NodeDef("Cast",
|
||||
{IA("Mask_" + std::to_string(i))},
|
||||
{mask_cast_i_def},
|
||||
{MakeAttribute("to", int64_t(IElemType(0)))}));
|
||||
result.push_back(NodeDef("Mul", {mask_cast_i_def, GO(0)}, {pre_reduce_grad_i_def}));
|
||||
if (a.name.compare(b.name) == 0) {
|
||||
result.push_back(NodeDef("Identity", {pre_reduce_grad_i_def}, {GI(i)}));
|
||||
continue;
|
||||
}
|
||||
|
||||
result.emplace_back(NodeDef("Mul", {dy_scaled, IA("Mask_Cast_" + std::to_string(i))}, {pre_reduce_grad_i_def}));
|
||||
std::vector<Dimension> x_shape;
|
||||
if (get_y_shape_ok && GetShape(x, x_shape).IsOK()) {
|
||||
if (has_y_shape && GetShape(x, x_shape).IsOK()) {
|
||||
std::vector<int64_t> x_axes;
|
||||
ComputeBroadcastBackwardAxes(x_shape, y_shape, &x_axes, nullptr, NodeName());
|
||||
if (x_axes.size() > 0) {
|
||||
if (!x_axes.empty()) {
|
||||
HandleBroadcasting(pre_reduce_grad_i_def, x, GI(i), x_axes, result);
|
||||
} else {
|
||||
result.push_back(NodeDef("Identity", {pre_reduce_grad_i_def}, {GI(i)}));
|
||||
result.emplace_back(NodeDef("Identity", {pre_reduce_grad_i_def}, {GI(i)}));
|
||||
}
|
||||
} else {
|
||||
ArgDef x_axes_def = IA("ReduceAxes_" + x.name);
|
||||
|
|
|
|||
|
|
@ -2519,10 +2519,12 @@ void GradientCheckerMinMaxGradHelper(const std::string op) {
|
|||
float max_error;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
OpDef op_def{op, kOnnxDomain, 11};
|
||||
// Ensure the gap between x1 and x2 is greater than 1e-3f, otherwise the result of NumericJacobian
|
||||
// Ensure the gap between tensors is greater than 1e-3f, otherwise the result of NumericJacobian
|
||||
// will be incorrect. This also excludes the equal-inputs case, where Min/Max is not smooth.
|
||||
std::function<float(float)> x1_transformer = [](float x) { return (int)(x * 100) / 100.f; };
|
||||
std::function<float(float)> x2_transformer = [](float x) { return (int)(x * 100) / 100.f + 0.002f; };
|
||||
std::function<float(float)> x3_transformer = [](float x) { return (int)(x * 100) / 100.f + 0.004f; };
|
||||
std::function<float(float)> x4_transformer = [](float x) { return (int)(x * 100) / 100.f + 0.006f; };
|
||||
TensorInfo x1_info({2, 3}, true, &x1_transformer);
|
||||
TensorInfo y_info({2, 3}, true);
|
||||
|
||||
|
|
@ -2543,6 +2545,23 @@ void GradientCheckerMinMaxGradHelper(const std::string op) {
|
|||
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x1_info, x2_info}, {y_info}, &max_error));
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
// More than 2 inputs.
|
||||
{
|
||||
TensorInfo x2_info({2, 3}, true, &x2_transformer);
|
||||
TensorInfo x3_info({2, 3}, true, &x3_transformer);
|
||||
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x1_info, x2_info, x3_info}, {y_info}, &max_error));
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
{
|
||||
TensorInfo x2_info({3}, true, &x2_transformer);
|
||||
TensorInfo x3_info({2, 1}, true, &x3_transformer);
|
||||
TensorInfo x4_info({2, 3}, true, &x4_transformer);
|
||||
ASSERT_STATUS_OK(
|
||||
gradient_checker.ComputeGradientError(op_def, {x1_info, x2_info, x3_info, x4_info}, {y_info}, &max_error));
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, MinGrad) {
|
||||
|
|
|
|||
|
|
@ -2,44 +2,41 @@
|
|||
# Licensed under the MIT License.
|
||||
# orttraining_test_ortmodule_api.py
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
import math
|
||||
import random
|
||||
import copy
|
||||
import torch
|
||||
from transformers import AutoConfig, BertForSequenceClassification, Trainer
|
||||
from transformers.modeling_outputs import SequenceClassifierOutput
|
||||
import pytest
|
||||
from time import sleep
|
||||
import warnings
|
||||
from unittest.mock import patch
|
||||
from collections import OrderedDict
|
||||
from collections import namedtuple
|
||||
from inspect import signature
|
||||
import tempfile
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
import tempfile
|
||||
import warnings
|
||||
from collections import OrderedDict, namedtuple
|
||||
from distutils.version import LooseVersion
|
||||
from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient
|
||||
from onnxruntime.training.ortmodule import (
|
||||
ORTModule,
|
||||
_utils,
|
||||
_io,
|
||||
DebugOptions,
|
||||
LogLevel,
|
||||
_fallback,
|
||||
_graph_execution_manager,
|
||||
)
|
||||
import onnxruntime.training.ortmodule as ortmodule_module
|
||||
|
||||
from onnxruntime.training.optim import FusedAdam, AdamWMode
|
||||
from transformers import AdamW
|
||||
from inspect import signature
|
||||
from time import sleep
|
||||
from unittest.mock import patch
|
||||
|
||||
import _test_helpers
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Import autocasting libs
|
||||
from torch.cuda import amp
|
||||
from transformers import AdamW, AutoConfig, BertForSequenceClassification, Trainer
|
||||
from transformers.modeling_outputs import SequenceClassifierOutput
|
||||
|
||||
import onnxruntime.training.ortmodule as ortmodule_module
|
||||
from onnxruntime.training.optim import AdamWMode, FusedAdam
|
||||
from onnxruntime.training.ortmodule import (
|
||||
DebugOptions,
|
||||
LogLevel,
|
||||
ORTModule,
|
||||
_fallback,
|
||||
_graph_execution_manager,
|
||||
_io,
|
||||
_utils,
|
||||
)
|
||||
from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient
|
||||
|
||||
DEFAULT_OPSET = 14
|
||||
|
||||
|
|
@ -1123,9 +1120,12 @@ def test_gradient_correctness_max(operator, dim, keepdim):
|
|||
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
|
||||
|
||||
|
||||
@pytest.mark.skip("temporarily disabled due to PyTorch's change of max implementation")
|
||||
# Before 1.10 (exclusive), Torch's min/max(x, y) assigned all of dY to y's dX when the values from x and y were equal.
|
||||
# From 1.10 on, both x's and y's dX are dY/2. ORT follows this distribution logic, so skip the test below if the Torch version
|
||||
# is lower than 1.10.
|
||||
@pytest.mark.skipif(LooseVersion(torch.__version__) < LooseVersion("1.10.0"), reason="PyTorch 1.9 incompatible")
|
||||
@pytest.mark.parametrize("operator", ["min", "max"])
|
||||
def test_gradient_correctness_max_two_tensors(operator):
|
||||
def test_gradient_correctness_minmax_two_tensors(operator):
|
||||
func = getattr(torch, operator)
|
||||
|
||||
class NeuralNetMaxTwoTensors(torch.nn.Module):
|
||||
|
|
@ -1153,6 +1153,19 @@ def test_gradient_correctness_max_two_tensors(operator):
|
|||
|
||||
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
|
||||
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
|
||||
_test_helpers.assert_values_are_close(ort_other.grad, pt_other.grad)
|
||||
|
||||
# Simple test for the case where some elements of the two inputs are equal.
|
||||
pt_input = torch.tensor([0.0, 0.0, 1.0, 1.0], device=device, requires_grad=True)
|
||||
pt_other = torch.tensor([1.0, 0.0, 1.0, 0.0], device=device, requires_grad=True)
|
||||
ort_input = copy.deepcopy(pt_input)
|
||||
ort_other = copy.deepcopy(pt_other)
|
||||
pt_prediction = run_step(pt_model, pt_input, pt_other)
|
||||
ort_prediction = run_step(ort_model, ort_input, ort_other)
|
||||
|
||||
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
|
||||
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
|
||||
_test_helpers.assert_values_are_close(ort_other.grad, pt_other.grad)
|
||||
|
||||
|
||||
def test_gradient_correctness_argmax_unfold():
|
||||
|
|
@ -2711,7 +2724,7 @@ def test_forward_data_and_model_on_different_devices(data_device, model_device):
|
|||
|
||||
# Now that the model has been exported, feed in data from device other than the model device
|
||||
x = torch.randn(N, D_in, device=data_device)
|
||||
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy, ORTModuleDeviceException
|
||||
from onnxruntime.training.ortmodule._fallback import ORTModuleDeviceException, _FallbackPolicy
|
||||
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
|
||||
# Fallback
|
||||
|
|
|
|||
Loading…
Reference in a new issue