From 0e59668c1b8974afa9ec8e09e2da7b47dcbbb86b Mon Sep 17 00:00:00 2001
From: Xueyun Zhu <40807589+xzhu1900@users.noreply.github.com>
Date: Wed, 6 May 2020 10:40:57 -0700
Subject: [PATCH] add support for symbolic broadcast for Add/Sub/Mul (#3743)

* add support for symbolic broadcast

* fix comment

* address feedback
---
 .../test/providers/provider_test_utils.h    | 37 +++++++++++++++----
 .../core/graph/gradient_builder_base.cc     | 26 ++++++++++++-
 .../test/gradient/gradient_checker.cc       | 31 ++++++++++++----
 .../test/gradient/gradient_checker.h        | 12 ++++--
 .../test/gradient/gradient_op_test_utils.cc |  2 +-
 .../test/gradient/gradient_ops_test.cc      | 21 +++++++++++
 6 files changed, 108 insertions(+), 21 deletions(-)

diff --git a/onnxruntime/test/providers/provider_test_utils.h b/onnxruntime/test/providers/provider_test_utils.h
index abb87684d6..50befb3191 100644
--- a/onnxruntime/test/providers/provider_test_utils.h
+++ b/onnxruntime/test/providers/provider_test_utils.h
@@ -250,19 +250,19 @@ class OpTester {
   // bool and we can't get the raw data out. So those cases must use an initializer_list
   template <typename T>
   void AddInput(const char* name, const std::vector<int64_t>& dims, const std::initializer_list<T>& values,
-                bool is_initializer = false) {
-    AddData(input_data_, name, dims, values.begin(), values.size(), is_initializer);
+                bool is_initializer = false, const std::vector<std::string>* dim_params = nullptr) {
+    AddData(input_data_, name, dims, values.begin(), values.size(), is_initializer, false, dim_params);
   }
 
   template <typename T>
   void AddInput(const char* name, const std::vector<int64_t>& dims, const std::vector<T>& values,
-                bool is_initializer = false) {
-    AddData(input_data_, name, dims, values.data(), values.size(), is_initializer);
+                bool is_initializer = false, const std::vector<std::string>* dim_params = nullptr) {
+    AddData(input_data_, name, dims, values.data(), values.size(), is_initializer, false, dim_params);
   }
 
   template <typename T>
-  void AddInput(const char* name, const std::vector<int64_t>& dims, const T* p_values, const size_t size, bool is_initializer = false) {
-    AddData(input_data_, name, dims, p_values, size, is_initializer);
+  void AddInput(const char* name, const std::vector<int64_t>& dims, const T* p_values, const size_t size, bool is_initializer = false, const std::vector<std::string>* dim_params = nullptr) {
+    AddData(input_data_, name, dims, p_values, size, is_initializer, false, dim_params);
   }
 
   // Add other registered types, possibly experimental
@@ -505,7 +505,8 @@ class OpTester {
 protected:
  template <typename T>
  void AddData(std::vector<Data>& data, const char* name, const std::vector<int64_t>& dims, const T* values,
-              int64_t values_count, bool is_initializer = false, bool sort_output = false) {
+              int64_t values_count, bool is_initializer = false, bool sort_output = false,
+              const std::vector<std::string>* dim_params = nullptr) {
    try {
      TensorShape shape{dims};
      ORT_ENFORCE(shape.Size() == values_count, values_count, " input values doesn't match tensor size of ",
@@ -529,7 +530,27 @@ class OpTester {
 
       OrtValue value;
       value.Init(p_tensor.release(), DataTypeImpl::GetType<Tensor>(), DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
-      data.push_back(Data(NodeArg(name, &type_proto), std::move(value), optional<float>(), optional<float>(), sort_output));
+      auto node_arg = NodeArg(name, &type_proto);
+      if (dim_params && !(dim_params->empty())) {
+        // If dim_params is present, configure node_arg's dim values based on dim_params, which supports symbolic dims and dim broadcast.
+        auto& dim_params_data = *dim_params;
+        onnx::TensorShapeProto new_shape;
+
+        // currently hard-code the reserved symbolic names.
+        // TODO: when the list grows longer, consider moving it to a better place.
+        const static std::unordered_set<std::string> reserved_symbolic{"batch", "seq"};
+
+        for (size_t i = 0; i < dim_params_data.size(); ++i) {
+          if (reserved_symbolic.find(dim_params_data[i]) != reserved_symbolic.end()) {
+            new_shape.add_dim()->set_dim_param(dim_params_data[i]);
+          } else {
+            ASSERT_TRUE(std::stoi(dim_params_data[i]) == dims[i]);
+            new_shape.add_dim()->set_dim_value(dims[i]);
+          }
+        }
+        node_arg.SetShape(new_shape);
+      }
+      data.push_back(Data(std::move(node_arg), std::move(value), optional<float>(), optional<float>(), sort_output));
       if (is_initializer) initializer_index_.push_back(data.size() - 1);
     } catch (const std::exception& ex) {
       std::cerr << "AddData for '" << name << "' threw: " << ex.what();
diff --git a/orttraining/orttraining/core/graph/gradient_builder_base.cc b/orttraining/orttraining/core/graph/gradient_builder_base.cc
index 3a63f2f8fa..207f10850b 100644
--- a/orttraining/orttraining/core/graph/gradient_builder_base.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder_base.cc
@@ -39,10 +39,32 @@ void ComputeBroadcastBackwardAxes(
       auto A_dim = A_dims[i].dim_param(),
            B_dim = B_dims[j].dim_param();
       if (A_dim != B_dim) {
-        ORT_THROW("Error");
+        ORT_THROW("Error: symbolic dimension doesn't match. Expected the same symbolic dimension but got \"",
+                  A_dim, "\" and \"", B_dim, "\".");
+      }
+    } else if (A_dims[i].has_dim_param() && B_dims[j].has_dim_value()) {
+      auto A_dim = A_dims[i].dim_param();
+      auto B_dim = B_dims[j].dim_value();
+
+      if (B_dim != 1) {
+        ORT_THROW("Error: symbolic broadcasting requires the corresponding dimension to be 1. ",
+                  "Actually got ", B_dim);
+      }
+      if (B_axes) {
+        B_axes->push_back(gsl::narrow_cast<int64_t>(k));
+      }
+    } else if (A_dims[i].has_dim_value() && B_dims[j].has_dim_param()) {
+      auto A_dim = A_dims[i].dim_value();
+      auto B_dim = B_dims[j].dim_param();
+
+      if (A_dim != 1) {
+        ORT_THROW("Error: symbolic broadcasting requires the corresponding dimension to be 1. ",
+                  "Actually got ", A_dim);
+      }
+      if (A_axes) {
+        A_axes->push_back(gsl::narrow_cast<int64_t>(k));
       }
     }
-    // TODO : complete othere cases
 
     --i;
     --j;
diff --git a/orttraining/orttraining/test/gradient/gradient_checker.cc b/orttraining/orttraining/test/gradient/gradient_checker.cc
index 4fc263a461..e923b81926 100644
--- a/orttraining/orttraining/test/gradient/gradient_checker.cc
+++ b/orttraining/orttraining/test/gradient/gradient_checker.cc
@@ -129,7 +129,7 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::ComputeTheoreticalJacobianTransp
 
   // Compute the theoretical Jacobians one row at a time by back propagating
   // '1.0' for each element of 'dy', while holding all other elements of 'dy' at zero.
-  for (int c = 0; c < dy_size; ++c) {  // for each value in the dy input vector
+  for (size_t c = 0; c < dy_size; ++c) {  // for each value in the dy input vector
     // clear OpTester input/output/initializer
     op_session.ClearData();
 
@@ -167,7 +167,7 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::ComputeTheoreticalJacobianTransp
       // inputs is treated as a vector of vectors. The parameters of the function call below, y_idx and c,
      // correspond to which input (dy1, dy2..etc) and which value of the input (dy_flattened_vector[c])
      // to perturb to 1.
-      op_session.Run(y_idx, c);
+      op_session.Run(y_idx, static_cast<int>(c));
       auto gradients = op_session.GetFetches();
 
       for (int x_idx = 0, grad_idx = 0; x_idx < static_cast<int>(x_num); x_idx++) {
@@ -186,7 +186,7 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::ComputeTheoreticalJacobianTransp
                                                      r,
                                                      y_infos,
                                                      y_idx,
-                                                     c);
+                                                     static_cast<int>(c));
           (*jacobian_ts)[calc_index.first][calc_index.second] = dx_flat[r];
         }
       }
@@ -211,19 +211,36 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::InitOpTesterWithGraph(
     if (x_infos[data_index].data_type == DataTypeImpl::GetTensorType<int64_t>()) {
       std::vector<int64_t> int64_data(data.size());
       std::transform(data.begin(), data.end(), int64_data.begin(), [](X_T x) { return static_cast<int64_t>(x); });
-      op_session.AddInput(name.c_str(), x_infos[data_index].shape.GetDims(), int64_data);
+      op_session.AddInput(name.c_str(),
+                          x_infos[data_index].shape.GetDims(),
+                          int64_data,
+                          false,
+                          &x_infos[data_index].dim_params);
     } else if (x_infos[data_index].data_type == DataTypeImpl::GetTensorType<int32_t>()) {
       std::vector<int32_t> int32_data(data.size());
       std::transform(data.begin(), data.end(), int32_data.begin(), [](X_T x) { return static_cast<int32_t>(x); });
-      op_session.AddInput(name.c_str(), x_infos[data_index].shape.GetDims(), int32_data);
+      op_session.AddInput(name.c_str(),
+                          x_infos[data_index].shape.GetDims(),
+                          int32_data,
+                          false,
+                          &x_infos[data_index].dim_params);
     } else if (x_infos[data_index].data_type == DataTypeImpl::GetTensorType<bool>()) {
       std::unique_ptr<bool[]> p_data(new bool[data.size()]);
       for (size_t i = 0; i < data.size(); ++i) {
         p_data[i] = static_cast<bool>(data[i]);
       }
-      op_session.AddInput(name.c_str(), x_infos[data_index].shape.GetDims(), p_data.get(), data.size());
+      op_session.AddInput(name.c_str(),
+                          x_infos[data_index].shape.GetDims(),
+                          p_data.get(),
+                          data.size(),
+                          false,
+                          &x_infos[data_index].dim_params);
     } else {
-      op_session.AddInput(name.c_str(), x_infos[data_index].shape.GetDims(), data);
+      op_session.AddInput(name.c_str(),
+                          x_infos[data_index].shape.GetDims(),
+                          data,
+                          false,
+                          &x_infos[data_index].dim_params);
     }
   }
 
diff --git a/orttraining/orttraining/test/gradient/gradient_checker.h b/orttraining/orttraining/test/gradient/gradient_checker.h
index 3a81f711b7..bd17469228 100644
--- a/orttraining/orttraining/test/gradient/gradient_checker.h
+++ b/orttraining/orttraining/test/gradient/gradient_checker.h
@@ -26,8 +26,13 @@ struct TensorInfo {
   TensorInfo(const std::initializer_list<int64_t>& shape,
              bool has_gradient = true,
             std::function<float(float)>* transformer = nullptr,
-             MLDataType data_type = DataTypeImpl::GetTensorType<float>())
-      : shape(shape), has_gradient(has_gradient), transformer(transformer), data_type(data_type) {}
+             MLDataType data_type = DataTypeImpl::GetTensorType<float>(),
+             const std::vector<std::string>& dim_params = std::vector<std::string>{})
+      : shape(shape),
+        has_gradient(has_gradient),
+        transformer(transformer),
+        data_type(data_type),
+        dim_params(dim_params) {}
 
   TensorInfo(const TensorShape& shape,
              bool has_gradient = true,
@@ -39,6 +44,7 @@ struct TensorInfo {
   bool has_gradient;
   std::function<float(float)>* transformer;
   MLDataType data_type;
+  std::vector<std::string> dim_params;
 };
 
 // TODO: This class currently assumes the inputs share types and the outputs share a type.
@@ -85,7 +91,7 @@ class GradientChecker {
                                              const std::vector<TensorInfo>& y_infos,
                                              std::vector<std::vector<JAC_T>>* jacobians);
 
-  std::vector<Y_T> EvaluateFunctionAtInput(OpTester& op_tester,
+  std::vector<Y_T> EvaluateFunctionAtInput(OpTester& op_tester,
                                            const std::vector<TensorInfo>& x_infos,
                                            const std::vector<TensorInfo>& y_infos,
                                            std::vector<std::vector<X_T>>* x_datas,
diff --git a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc
index 7bc673bc9d..c8d546dba3 100644
--- a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc
+++ b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc
@@ -209,7 +209,7 @@ void GradientOpTester::FillFeedsAndOutputNames(std::unordered_map<std::string,
       auto shape = output_data_[i].data_.Get<Tensor>().Shape();
       std::vector<float> values(shape.Size(), 0.0);
 
-      if (output_index_to_use_as_loss == i) {
+      if (output_index_to_use_as_loss == static_cast<int>(i)) {
         values[data_index_of_output] = 1.0;  //set only one value to one to construct jacobian matrix
       }
       AddData(gradient_data, (output_data_[i].def_.Name() + "_grad").c_str(), shape.GetDims(), values.data(), values.size(), true);
diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
index a407d07e7d..5b22f22b7f 100644
--- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc
+++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
@@ -185,6 +185,27 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type,
     gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error);
     EXPECT_IS_TINY(max_error);
   }
+
+  // symbolic broadcast
+  // shape(A) = (4, 2, 1, "seq(3)"), shape(B) = (4, 2, 1, 1), ==> shape(result) = (4, 2, 1, 3)
+  {
+    TensorInfo A_info{{4, 2, 1, 3}, true, transformer, DataTypeImpl::GetTensorType<float>(), {"4", "2", "1", "seq"}};
+    TensorInfo B_info{{4, 2, 1, 1}, true, transformer, DataTypeImpl::GetTensorType<float>(), {"4", "2", "1", "1"}};
+    TensorInfo Y_info{{4, 2, 1, 3}};
+
+    gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error);
+    EXPECT_IS_TINY(max_error);
+  }
+  // symbolic broadcast + numeric broadcast
+  // shape(A) = ("batch(4)", 2, "seq(3)", "seq(3)"), shape(B) = ("batch(4)", 1, 1, "seq(3)"), ==> shape(result) = (4, 2, 3, 3)
+  {
+    TensorInfo A_info{{4, 2, 3, 3}, true, transformer, DataTypeImpl::GetTensorType<float>(), {"batch", "2", "seq", "seq"}};
+    TensorInfo B_info{{4, 1, 1, 3}, true, transformer, DataTypeImpl::GetTensorType<float>(), {"batch", "1", "1", "seq"}};
+    TensorInfo Y_info{{4, 2, 3, 3}};
+
+    gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error);
+    EXPECT_IS_TINY(max_error);
+  }
 }
 
 TEST(GradientCheckerTest, AddGrad) {
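
The rule the patch adds to ComputeBroadcastBackwardAxes can be summarized as: two symbolic dimensions only match if they carry the same symbol, a symbolic dimension may only be broadcast against a concrete dimension when that concrete dimension is 1, and every output axis along which an input was broadcast must be reduce-summed when that input's gradient is computed. The standalone sketch below restates the rule outside of ONNX Runtime. It is a simplified illustration only, not the ONNX Runtime implementation: the Dim struct, ComputeBackwardAxes, and the main() driver are hypothetical names introduced here, and error handling is collapsed into exceptions.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// A dimension is either a concrete value or a symbolic name such as "batch" or "seq".
struct Dim {
  long long value;     // used when symbol is empty
  std::string symbol;  // non-empty for a symbolic dimension
  bool IsSymbolic() const { return !symbol.empty(); }
};

// Right-align the two shapes and record, for every output axis k, whether A or B was
// broadcast along it; the gradient of that input must be reduce-summed over those axes.
void ComputeBackwardAxes(const std::vector<Dim>& a, const std::vector<Dim>& b,
                         std::vector<int64_t>* a_axes, std::vector<int64_t>* b_axes) {
  int i = static_cast<int>(a.size()) - 1;
  int j = static_cast<int>(b.size()) - 1;
  int k = static_cast<int>(std::max(a.size(), b.size())) - 1;
  for (; i >= 0 && j >= 0; --i, --j, --k) {
    const Dim& A = a[i];
    const Dim& B = b[j];
    if (!A.IsSymbolic() && !B.IsSymbolic()) {
      // Plain numeric broadcasting: the side that is 1 gets reduced.
      if (A.value != B.value) {
        if (A.value == 1) {
          a_axes->push_back(k);
        } else if (B.value == 1) {
          b_axes->push_back(k);
        } else {
          throw std::runtime_error("incompatible concrete dimensions");
        }
      }
    } else if (A.IsSymbolic() && B.IsSymbolic()) {
      // Two symbols only match if they are literally the same name.
      if (A.symbol != B.symbol) throw std::runtime_error("mismatched symbolic dimensions");
    } else if (A.IsSymbolic()) {
      // Symbolic vs. concrete: only legal when the concrete side is 1, and then B is broadcast.
      if (B.value != 1) throw std::runtime_error("symbolic broadcast requires the other dimension to be 1");
      b_axes->push_back(k);
    } else {
      if (A.value != 1) throw std::runtime_error("symbolic broadcast requires the other dimension to be 1");
      a_axes->push_back(k);
    }
  }
  // Leading output axes present in only one input: the shorter input is broadcast there.
  for (; i >= 0; --i, --k) b_axes->push_back(k);
  for (; j >= 0; --j, --k) a_axes->push_back(k);
}

int main() {
  // Shapes from the new test case: A = ("batch", 2, "seq", "seq"), B = ("batch", 1, 1, "seq").
  std::vector<Dim> A = {{-1, "batch"}, {2, ""}, {-1, "seq"}, {-1, "seq"}};
  std::vector<Dim> B = {{-1, "batch"}, {1, ""}, {1, ""}, {-1, "seq"}};
  std::vector<int64_t> a_axes, b_axes;
  ComputeBackwardAxes(A, B, &a_axes, &b_axes);
  // Expected output: axes 2 and 1 are reported for B, while a_axes stays empty.
  for (int64_t axis : b_axes) std::cout << "reduce gradient of B over axis " << axis << "\n";
  return 0;
}

On the shapes from the new test, the sketch reports that the gradient flowing to B must be reduced over axes 2 and 1 while A needs no reduction, which is the reduce-sum the Add/Sub/Mul gradient builders apply to the incoming gradient once the broadcast axes are known.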