From c23b484275794aa4826b5495202746bba3cc7a3d Mon Sep 17 00:00:00 2001 From: George Wu Date: Sun, 26 Apr 2020 17:59:36 -0700 Subject: [PATCH 1/8] add missing deps in Dockerfile.openvino --- dockerfiles/Dockerfile.openvino | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dockerfiles/Dockerfile.openvino b/dockerfiles/Dockerfile.openvino index 7f1238818c..958d31bc15 100644 --- a/dockerfiles/Dockerfile.openvino +++ b/dockerfiles/Dockerfile.openvino @@ -7,7 +7,7 @@ FROM ubuntu:16.04 RUN apt update && \ apt -y install git sudo wget \ - zip x11-apps lsb-core cpio libboost-python-dev libpng-dev zlib1g-dev libnuma1 ocl-icd-libopencl1 clinfo libboost-filesystem1.58.0 libboost-thread1.58.0 protobuf-compiler libprotoc-dev libusb-1.0-0-dev + zip x11-apps lsb-core cpio libboost-python-dev libpng-dev zlib1g-dev libnuma1 ocl-icd-libopencl1 clinfo libboost-filesystem1.58.0 libboost-thread1.58.0 protobuf-compiler libprotoc-dev libusb-1.0-0-dev autoconf automake libtool ARG DEVICE=CPU_FP32 ARG ONNXRUNTIME_REPO=https://github.com/microsoft/onnxruntime From d9016408173171b0c439df93b0d05ee3e6a46596 Mon Sep 17 00:00:00 2001 From: Prabhat Date: Mon, 27 Apr 2020 17:41:33 +0530 Subject: [PATCH 2/8] Call optimised version of depthwise ConvLayer (#3664) * Call optimised version of depthwise ConvLayer * Update if statements --- onnxruntime/core/providers/acl/nn/conv.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/acl/nn/conv.cc b/onnxruntime/core/providers/acl/nn/conv.cc index f95d6cd6ef..c7ccaf5f04 100644 --- a/onnxruntime/core/providers/acl/nn/conv.cc +++ b/onnxruntime/core/providers/acl/nn/conv.cc @@ -208,7 +208,7 @@ Status Conv::Compute(OpKernelContext* context) const { if(optimizable) { //optimized depthwise convolution #if defined(ACL_1902) || defined(ACL_1905) - auto layer = std::make_shared(); + auto layer = std::make_shared(); #endif #ifdef ACL_1908 auto layer = std::make_shared(); From 
635bc9cd0492504b86883327b62a7b12d22181ba Mon Sep 17 00:00:00 2001 From: Sherlock Date: Mon, 27 Apr 2020 11:53:45 -0700 Subject: [PATCH 3/8] Fix graph transformers to support opset 12 ops (#3715) --- onnxruntime/core/optimizer/fast_gelu_fusion.cc | 2 +- onnxruntime/core/optimizer/layer_norm_fusion.cc | 4 ++-- .../orttraining/core/optimizer/insert_output_rewriter.cc | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/optimizer/fast_gelu_fusion.cc b/onnxruntime/core/optimizer/fast_gelu_fusion.cc index ca122f79ba..c71555bb83 100644 --- a/onnxruntime/core/optimizer/fast_gelu_fusion.cc +++ b/onnxruntime/core/optimizer/fast_gelu_fusion.cc @@ -108,7 +108,7 @@ MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node, MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node, std::vector>& nodes_to_fuse) const { MatchResult matchResult{false, nullptr, nullptr}; - if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow1_node, "Pow", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow1_node, "Pow", {7, 12}) || !graph_utils::IsSupportedProvider(pow1_node, GetCompatibleExecutionProviders()) || pow1_node.GetOutputEdgesCount() != 1 || !IsSupportedDataType(pow1_node)) { diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index 327639f425..dddb62bf80 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -169,10 +169,10 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, continue; } nodes_to_remove.push_back(reduce_mean2_node); - + // Traceback the reduceMean node to find pow --> reduceMean Node& pow_node = *graph.GetNode(reduce_mean2_node.InputNodesBegin()->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow_node, "Pow", {7}) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow_node, "Pow", {7, 12}) || 
pow_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || pow_node.GetOutputEdgesCount() != 1 || !IsSupportedDataType(pow_node)) { diff --git a/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc b/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc index b1ef4b18e5..ceff7c3f77 100644 --- a/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc +++ b/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc @@ -25,7 +25,7 @@ Status InsertMaxPoolOutput::Apply(Graph& graph, Node& node, RewriteRuleEffect& r } bool InsertMaxPoolOutput::SatisfyCondition(const Graph& /*graph*/, const Node& node, const logging::Logger& /*logger*/) const { - if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "MaxPool", {8}) && + if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "MaxPool", {8, 10, 11, 12}) && node.OutputDefs().size() == 1) { return true; } From 7627e6bcc26893a96344ad90fabb11cd76530348 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Mon, 27 Apr 2020 13:57:24 -0700 Subject: [PATCH 4/8] Improve node and node argument name generation (#3649) --- include/onnxruntime/core/graph/graph.h | 8 +++ onnxruntime/core/graph/graph.cc | 53 ++++++++++++++----- onnxruntime/core/providers/cpu/math/clip.cc | 12 +++++ onnxruntime/core/providers/cpu/math/clip.h | 11 +--- .../src/GraphDescBuilder.cpp | 21 ++++++-- 5 files changed, 79 insertions(+), 26 deletions(-) diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index f08df584ba..12245aeb51 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1106,6 +1106,14 @@ class Graph { // Graph value_info. std::vector value_info_; + // Strings which have been used as node names. + // New node name should not conflict with this set. + std::unordered_set generated_node_names_; + + // Strings which have been used as node_arg names. 
+ // New node_arg name should not conflict with this set. + std::unordered_set generated_node_arg_names_; + // All node args owned by <*this> graph. Key is node arg name. std::unordered_map> node_args_; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index a5e4c3c73f..d939d25876 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2378,28 +2378,57 @@ Node& Graph::AddNode(const NodeProto& node_proto, } std::string Graph::GenerateNodeArgName(const std::string& base_name) { - std::string new_name; - do { + std::string new_name = base_name; + // Check if new_name has been used as any of node_args_' names. + // Check if new_name has been generated by this function. + // If both are not, add new_name into name set and return the new_name + // as the generated name. Otherwise, keep generating new names. + while (node_args_.find(new_name) != node_args_.end() || + generated_node_arg_names_.find(new_name) != generated_node_arg_names_.end()) { std::ostringstream str; - str << base_name << "_" << name_generator_++; + str << base_name << "_token_" << name_generator_++; new_name = str.str(); - } while (node_args_.find(new_name) != node_args_.end()); + } + + generated_node_arg_names_.insert(new_name); return new_name; } std::string Graph::GenerateNodeName(const std::string& base_name) { - std::string new_name; - bool keep_going = true; + // Define name-checking function for node name. + // Return true if the input name hasn't been used. Otherwise, return false. + auto name_is_ok = [&] (const std::string name) { + for (auto it = nodes_.begin(); it != nodes_.end(); ++it) { + if (*it == nullptr) { + continue; + } + if (it->get()->Name() != name) { + continue; + } + // Find a matched name so we cannot reuse the input name. + return false; + } - do { + if (generated_node_names_.find(name) != generated_node_names_.end()) { + // Find a matched name so we cannot reuse the input name. 
+ return false; + } + + // The input name can be reused. + return true; + }; + + // Start with the input name. + std::string new_name = base_name; + + while (!name_is_ok(new_name)) { std::ostringstream str; - str << base_name << "_" << name_generator_++; + str << base_name << "_token_" << name_generator_++; new_name = str.str(); + } - keep_going = std::find_if(nodes_.cbegin(), nodes_.cend(), [&new_name](const std::unique_ptr& n) { - return (n != nullptr) && (n->Name() == new_name); - }) != nodes_.end(); - } while (keep_going); + // Make sure this new_name is not going to be reused. + generated_node_names_.insert(new_name); return new_name; } diff --git a/onnxruntime/core/providers/cpu/math/clip.cc b/onnxruntime/core/providers/cpu/math/clip.cc index 75ae6d4cdd..92c7590d53 100644 --- a/onnxruntime/core/providers/cpu/math/clip.cc +++ b/onnxruntime/core/providers/cpu/math/clip.cc @@ -3,6 +3,7 @@ #include "core/providers/cpu/math/clip.h" #include "core/framework/data_types_internal.h" +#include "core/util/math_cpuonly.h" namespace onnxruntime { @@ -31,6 +32,17 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( REG_KERNEL_NONTEMPL(Clip, 12, Clip, float, double, int8_t, uint8_t, int64_t, uint64_t); +template +Status Clip_6::Compute(OpKernelContext* ctx) const { + const auto* X = ctx->Input(0); + Tensor* Y = ctx->Output(0, X->Shape()); + EigenVectorMap(Y->template MutableData(), Y->Shape().Size()) = + ConstEigenVectorMap(X->template Data(), X->Shape().Size()) + .cwiseMax(this->min_) + .cwiseMin(this->max_); + return Status::OK(); +} + template struct Clip::ComputeImpl { void operator()(const Tensor* X, const Tensor* min, const Tensor* max, Tensor* Y) const { diff --git a/onnxruntime/core/providers/cpu/math/clip.h b/onnxruntime/core/providers/cpu/math/clip.h index b365aa33fd..b7c6032c6b 100644 --- a/onnxruntime/core/providers/cpu/math/clip.h +++ b/onnxruntime/core/providers/cpu/math/clip.h @@ -5,7 +5,6 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" -#include 
"core/util/math_cpuonly.h" namespace onnxruntime { @@ -33,15 +32,7 @@ class Clip_6 final : public clip_internal::Clip_6Base, public OpKernel { explicit Clip_6(const OpKernelInfo& info) : clip_internal::Clip_6Base(info), OpKernel(info) { } - Status Compute(OpKernelContext* ctx) const override { - const auto* X = ctx->Input(0); - Tensor* Y = ctx->Output(0, X->Shape()); - EigenVectorMap(Y->template MutableData(), Y->Shape().Size()) = - ConstEigenVectorMap(X->template Data(), X->Shape().Size()) - .cwiseMax(this->min_) - .cwiseMin(this->max_); - return Status::OK(); - } + Status Compute(OpKernelContext* ctx) const override; }; // Since version 11. Min and Max are inputs diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 622b7b96cb..b07eed9fd8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -14,8 +14,21 @@ namespace Dml::GraphDescBuilder // mismatch is fixed (WindowsAI: 21114358, Lotus: 1953), this workaround should be removed. static std::string GetFusedNodeArgNameMatchingGraph(const std::string& fusedNodeArgeName) { - // The suffix used when inserting mem copies is equal to the below, followed by an incrementing number. - const char* suffix = strstr(fusedNodeArgeName.c_str(), "_DmlExecutionProvider_"); + const char* suffix = nullptr; + + // The suffix used when inserting mem copies is equal to the below, probably followed by an incrementing number. + if (!suffix) { + suffix = strstr(fusedNodeArgeName.c_str(), "_DmlExecutionProvider_"); + } + + // The suffix used when inserting mem copies is equal to the below, not followed by an incrementing number. 
+ if (!suffix) { + suffix = strstr(fusedNodeArgeName.c_str(), "_DmlExecutionProvider"); + } + + if (!suffix) { + suffix = strstr(fusedNodeArgeName.c_str(), "_token_"); + } if (suffix) { @@ -23,9 +36,9 @@ namespace Dml::GraphDescBuilder fusedNodeArgeName.begin(), fusedNodeArgeName.begin() + (suffix - fusedNodeArgeName.c_str()) ); + } else { + return fusedNodeArgeName; } - - return fusedNodeArgeName; } const std::string& GetUniqueNodeName(const onnxruntime::Node& node) From 4f887b465a607fe559df882ac2a1673d17204f79 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 27 Apr 2020 14:24:54 -0700 Subject: [PATCH 5/8] Uncomment celu test. (#3717) --- onnxruntime/test/onnx/main.cc | 1 - onnxruntime/test/python/onnx_backend_test_series.py | 1 - 2 files changed, 2 deletions(-) diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 046f8e83e8..db42bfafac 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -495,7 +495,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) { {"bitshift_left_uint16", "BitShift(11) uint16 support not enabled currently"}, {"dropout_default", "result differs", {"onnxtip"}}, {"dropout_random", "result differs", {"onnxtip"}}, - {"celu", "invalid model", {"onnxtip"}}, {"maxunpool_export_with_output_shape", "Invalid output in ONNX test. 
See https://github.com/onnx/onnx/issues/2398"} }; diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index de036c8b1f..5b9d8760ae 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -91,7 +91,6 @@ def create_backend_test(testname=None): '^test_batchnorm_epsilon_training_mode_cpu', '^test_batchnorm_example_old_cpu', '^test_batchnorm_example_training_mode_cpu', - '^test_celu_cpu', '^test_dropout_default_cpu', '^test_dropout_random_cpu', '^test_einsum_batch_diagonal_cpu', From b990ba0059291addeb220e87a88ef9d2d7621baa Mon Sep 17 00:00:00 2001 From: ytaous <4484531+ytaous@users.noreply.github.com> Date: Mon, 27 Apr 2020 16:45:21 -0700 Subject: [PATCH 6/8] Merge from ort_training to master (#3719) * move cpu/cuda related files to coresponding cpu/cuda folder (#3668) Co-authored-by: Weixing Zhang * type cast for ratio is not necessary for dropout (#3682) Co-authored-by: Weixing Zhang * thrustallocator is not needed since cub is used directly for gather now. (#3683) Co-authored-by: Weixing Zhang * GatherND-12 Implementation (#3645) * Renamed, UT passing * Move GatherND CUDA Kerenl into onnxruntime * Merge GatherNDOpTest * Refactor Test code * Merge CPU Kernel Impl * Handle Negative Indice, Fix UT * Improve CUDA kernel to handle negative index * Minor Fixes * Preserve GatherND-1 Cuda kernel * Fix Mac build * fix UT * Fix Build * fix GatherNDOpTest.double > CUDA error cudaErrorInvalidDeviceFunction:invalid device function Co-authored-by: Sherlock Huang Co-authored-by: Peng Wang (pengwa) * Set gradient as output only for easy mode (#3694) * Support GPU Event Operators (#3653) * Add GPU event operators to support in-place updates in gradient accumulator and optimizer for modifying the tensors passing through those event operators. 
* Address comment and polish code * Merge shared code between CPU and GPU kernels * Move event test to a new file * Address comments * Update onnxruntime/core/providers/cuda/gpu_data_transfer.cc * fix path of cpu_featurizers_kernels.cc and cpu_featurizers_kernels.h Co-authored-by: Weixing Zhang Co-authored-by: Weixing Zhang Co-authored-by: Sherlock Co-authored-by: Sherlock Huang Co-authored-by: Peng Wang (pengwa) Co-authored-by: ashbhandare Co-authored-by: Wei-Sheng Chin Co-authored-by: Ethan Tao --- cmake/onnxruntime_providers.cmake | 10 - .../{ => cpu}/cpu_contrib_kernels.cc | 2 +- .../{ => cpu}/cpu_contrib_kernels.h | 0 .../{ => cuda}/cuda_contrib_kernels.cc | 0 .../{ => cuda}/cuda_contrib_kernels.h | 0 .../core/graph/contrib_ops/contrib_defs.cc | 71 ---- .../providers/acl/acl_execution_provider.cc | 2 +- .../providers/cpu/cpu_execution_provider.cc | 10 +- .../core/providers/cpu/tensor/gather_nd.cc | 49 ++- .../core/providers/cpu/tensor/gather_nd.h | 5 +- .../providers/cuda/cuda_execution_provider.cc | 8 +- .../core/providers/cuda/gpu_data_transfer.cc | 5 +- .../core/providers}/cuda/tensor/gather_nd.cc | 100 ++--- .../core/providers}/cuda/tensor/gather_nd.h | 19 +- .../providers}/cuda/tensor/gather_nd_impl.cu | 63 ++-- .../providers}/cuda/tensor/gather_nd_impl.h | 4 + .../{ => cpu}/cpu_featurizers_kernels.cc | 2 +- .../{ => cpu}/cpu_featurizers_kernels.h | 0 .../test/common/tensor_op_test_utils.h | 14 + .../providers/cpu/tensor/gather_nd_op_test.cc | 138 ++++++- .../core/graph/gradient_builder.cc | 8 +- .../core/graph/gradient_schema_defs.cc | 80 +++- .../core/graph/loss_func/bert_loss.cc | 4 +- .../python/orttraining_pybind_state.cc | 5 +- .../test/gradient/event_op_test.cc | 122 +++++++ .../test/gradient/gradient_checker.cc | 4 +- .../test/gradient/gradient_ops_test.cc | 119 ++---- .../cpu/tensor/gather_nd_grad_op_test.cc | 80 ++++ .../cpu/tensor/gather_nd_op_test.cc | 343 ------------------ .../training_ops/cpu/controlflow/common.h | 19 + 
.../cpu/controlflow/event_pool.cc | 16 +- .../training_ops/cpu/controlflow/event_pool.h | 4 + .../training_ops/cpu/controlflow/record.cc | 22 +- .../training_ops/cpu/controlflow/record.h | 3 + .../training_ops/cpu/controlflow/wait.cc | 31 +- .../training_ops/cpu/controlflow/wait.h | 5 +- .../{ => cpu}/cpu_training_kernels.cc | 4 +- .../{ => cpu}/cpu_training_kernels.h | 0 .../training_ops/cpu/tensor/gather_nd.cc | 138 ------- .../training_ops/cpu/tensor/gather_nd.h | 53 --- .../training_ops/cuda/controlflow/record.cc | 43 +++ .../training_ops/cuda/controlflow/record.h | 19 + .../training_ops/cuda/controlflow/wait.cc | 43 +++ .../training_ops/cuda/controlflow/wait.h | 19 + .../{ => cuda}/cuda_training_kernels.cc | 18 +- .../{ => cuda}/cuda_training_kernels.h | 0 .../training_ops/cuda/nn/dropout.cc | 5 +- .../training_ops/cuda/tensor/gather_grad.cc | 1 - .../cuda/tensor/gather_nd_gard_impl.cu | 47 +++ .../cuda/tensor/gather_nd_grad.cc | 67 ++++ .../training_ops/cuda/tensor/gather_nd_grad.h | 22 ++ .../cuda/tensor/thrustallocator.h | 30 -- 52 files changed, 926 insertions(+), 950 deletions(-) rename onnxruntime/contrib_ops/{ => cpu}/cpu_contrib_kernels.cc (99%) rename onnxruntime/contrib_ops/{ => cpu}/cpu_contrib_kernels.h (100%) rename onnxruntime/contrib_ops/{ => cuda}/cuda_contrib_kernels.cc (100%) rename onnxruntime/contrib_ops/{ => cuda}/cuda_contrib_kernels.h (100%) rename {orttraining/orttraining/training_ops => onnxruntime/core/providers}/cuda/tensor/gather_nd.cc (62%) rename {orttraining/orttraining/training_ops => onnxruntime/core/providers}/cuda/tensor/gather_nd.h (72%) rename {orttraining/orttraining/training_ops => onnxruntime/core/providers}/cuda/tensor/gather_nd_impl.cu (68%) rename {orttraining/orttraining/training_ops => onnxruntime/core/providers}/cuda/tensor/gather_nd_impl.h (91%) rename onnxruntime/featurizers_ops/{ => cpu}/cpu_featurizers_kernels.cc (99%) rename onnxruntime/featurizers_ops/{ => cpu}/cpu_featurizers_kernels.h (100%) create mode 
100644 orttraining/orttraining/test/gradient/event_op_test.cc create mode 100644 orttraining/orttraining/test/training_ops/cpu/tensor/gather_nd_grad_op_test.cc delete mode 100644 orttraining/orttraining/test/training_ops/cpu/tensor/gather_nd_op_test.cc create mode 100644 orttraining/orttraining/training_ops/cpu/controlflow/common.h rename orttraining/orttraining/training_ops/{ => cpu}/cpu_training_kernels.cc (98%) rename orttraining/orttraining/training_ops/{ => cpu}/cpu_training_kernels.h (100%) delete mode 100644 orttraining/orttraining/training_ops/cpu/tensor/gather_nd.cc delete mode 100644 orttraining/orttraining/training_ops/cpu/tensor/gather_nd.h create mode 100644 orttraining/orttraining/training_ops/cuda/controlflow/record.cc create mode 100644 orttraining/orttraining/training_ops/cuda/controlflow/record.h create mode 100644 orttraining/orttraining/training_ops/cuda/controlflow/wait.cc create mode 100644 orttraining/orttraining/training_ops/cuda/controlflow/wait.h rename orttraining/orttraining/training_ops/{ => cuda}/cuda_training_kernels.cc (97%) rename orttraining/orttraining/training_ops/{ => cuda}/cuda_training_kernels.h (100%) create mode 100644 orttraining/orttraining/training_ops/cuda/tensor/gather_nd_gard_impl.cu create mode 100644 orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.cc create mode 100644 orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.h delete mode 100644 orttraining/orttraining/training_ops/cuda/tensor/thrustallocator.h diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index c91daf3c4b..389146a348 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -7,15 +7,11 @@ file(GLOB_RECURSE onnxruntime_providers_srcs CONFIGURE_DEPENDS ) file(GLOB_RECURSE onnxruntime_cpu_contrib_ops_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/contrib_ops/cpu_contrib_kernels.h" - "${ONNXRUNTIME_ROOT}/contrib_ops/cpu_contrib_kernels.cc" 
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/*.h" "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/*.cc" ) file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/contrib_ops/cuda_contrib_kernels.h" - "${ONNXRUNTIME_ROOT}/contrib_ops/cuda_contrib_kernels.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.h" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cc" ) @@ -26,8 +22,6 @@ file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cu_srcs CONFIGURE_DEPENDS ) file(GLOB onnxruntime_cpu_featurizers_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/featurizers_ops/cpu_featurizers_kernels.h" - "${ONNXRUNTIME_ROOT}/featurizers_ops/cpu_featurizers_kernels.cc" "${ONNXRUNTIME_ROOT}/featurizers_ops/cpu/*.h" "${ONNXRUNTIME_ROOT}/featurizers_ops/cpu/*.cc" ) @@ -95,8 +89,6 @@ endif() if (onnxruntime_ENABLE_TRAINING) file(GLOB_RECURSE onnxruntime_cpu_training_ops_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu_training_kernels.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu_training_kernels.cc" "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/*.h" "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/*.cc" ) @@ -174,8 +166,6 @@ if (onnxruntime_USE_CUDA) if (onnxruntime_ENABLE_TRAINING) file(GLOB_RECURSE onnxruntime_cuda_training_ops_cc_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda_training_kernels.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda_training_kernels.cc" "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.h" "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.cc" ) diff --git a/onnxruntime/contrib_ops/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc similarity index 99% rename from onnxruntime/contrib_ops/cpu_contrib_kernels.cc rename to onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index e9c9d9f873..444904aa2c 100644 --- a/onnxruntime/contrib_ops/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Licensed under the MIT License. -#include "contrib_ops/cpu_contrib_kernels.h" +#include "contrib_ops/cpu/cpu_contrib_kernels.h" #include "core/graph/constants.h" #include "core/mlas/inc/mlas.h" diff --git a/onnxruntime/contrib_ops/cpu_contrib_kernels.h b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.h similarity index 100% rename from onnxruntime/contrib_ops/cpu_contrib_kernels.h rename to onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.h diff --git a/onnxruntime/contrib_ops/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc similarity index 100% rename from onnxruntime/contrib_ops/cuda_contrib_kernels.cc rename to onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc diff --git a/onnxruntime/contrib_ops/cuda_contrib_kernels.h b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.h similarity index 100% rename from onnxruntime/contrib_ops/cuda_contrib_kernels.h rename to onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.h diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index ecb9f5ccb6..aaa2258fe6 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2077,77 +2077,6 @@ Output = Dequantize(Input) -> AveragePool on fp32 data -> Quantize(output) .SetDoc(R"DOC( Given `data` tensor of rank r >= 1, and `indices` tensor of rank q >= 1, gather slices of `data` into an output tensor of rank q - 1 + r - indices[-1]. 
-Example 1: - data = [[0,1],[2,3]] - indices = [[0,0],[1,1]] - output = [0,3] -Example 2: - data = [[0,1],[2,3]] - indices = [[1],[0]] - output = [[2,3],[0,1]] -Example 3: - data = [[[0,1],[2,3]],[[4,5],[6,7]]] - indices = [[0,1],[1,0]] - output = [[2,3],[4,5]] -Example 4: - data = [[[0,1],[2,3]],[[4,5],[6,7]]] - indices = [[[0,1]],[[1,0]]] - output = [[[2,3]],[[4,5]]] -)DOC"); - - ONNX_CONTRIB_OPERATOR_SCHEMA(GatherND) - .SetDomain(kOnnxDomain) - .SinceVersion(1) - .Attr( - "axis", - "The number of batch dims. The gather of indexing starts from dimension of data[axis:]", - AttributeProto::INT, - static_cast(0)) - .Input(0, "data", "Tensor of rank r >= 1.", "T") - .Input(1, "indices", "Tensor of rank q >= 1.", "Tind") - .Output(0, "output", "Tensor of rank q-1+r-indices[-1].", "T") - .TypeConstraint( - "T", - OpSchema::all_tensor_types(), - "Constrain input and output types to any tensor type.") - .TypeConstraint( - "Tind", - {"tensor(int32)", "tensor(int64)"}, - "Constrain indice type to int32 or int64") - .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - propagateElemTypeFromInputToOutput(ctx, 0, 0); - if (!hasNInputShapes(ctx, 2)) { - return; - } - auto& data_shape = ctx.getInputType(0)->tensor_type().shape(); - auto& indices_shape = ctx.getInputType(1)->tensor_type().shape(); - auto data_rank = data_shape.dim_size(); - auto indices_rank = indices_shape.dim_size(); - auto axis = ctx.getAttribute("axis"); - int64_t axis_data = axis ? 
static_cast(axis->i()) : 0; - if (data_rank < 1 || indices_rank < 1) { - fail_shape_inference("both data and indices tensor need to have rank larger than zero."); - } - auto last_indice_dimension = indices_shape.dim(indices_rank - 1).dim_value() + axis_data; - if (last_indice_dimension > data_rank) { - fail_shape_inference("last dimension of indices must not be larger and rank of data tensor"); - } - for (int i = 0; i < indices_rank - 1; ++i) { - *ctx.getOutputType(0) - ->mutable_tensor_type() - ->mutable_shape() - ->add_dim() = indices_shape.dim(i); - } - for (int i = static_cast(last_indice_dimension); i < data_rank; ++i) { - *ctx.getOutputType(0) - ->mutable_tensor_type() - ->mutable_shape() - ->add_dim() = data_shape.dim(i); - } - }) - .SetDoc(R"DOC( -Given `data` tensor of rank r >= 1, and `indices` tensor of rank q >= 1, gather -slices of `data` into an output tensor of rank q - 1 + r - indices[-1]. Example 1: data = [[0,1],[2,3]] indices = [[0,0],[1,1]] diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.cc b/onnxruntime/core/providers/acl/acl_execution_provider.cc index 8319121abf..56ff6a759c 100644 --- a/onnxruntime/core/providers/acl/acl_execution_provider.cc +++ b/onnxruntime/core/providers/acl/acl_execution_provider.cc @@ -7,7 +7,7 @@ #include "core/framework/op_kernel.h" #include "core/framework/kernel_registry.h" #include "core/framework/compute_capability.h" -#include "contrib_ops/cpu_contrib_kernels.h" +#include "contrib_ops/cpu/cpu_contrib_kernels.h" #include "acl_fwd.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 21c6b9caea..c28e966899 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -6,15 +6,15 @@ #include "core/framework/kernel_registry.h" #ifndef DISABLE_CONTRIB_OPS -#include "contrib_ops/cpu_contrib_kernels.h" +#include 
"contrib_ops/cpu/cpu_contrib_kernels.h" #endif #ifdef ML_FEATURIZERS -#include "featurizers_ops/cpu_featurizers_kernels.h" +#include "featurizers_ops/cpu/cpu_featurizers_kernels.h" #endif #ifdef ENABLE_TRAINING -#include "orttraining/training_ops/cpu_training_kernels.h" +#include "orttraining/training_ops/cpu/cpu_training_kernels.h" #endif #include "core/framework/compute_capability.h" @@ -435,7 +435,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, int8_t, ReduceMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, uint8_t, ReduceMin); - +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, GatherND); Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -1082,6 +1082,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { ReduceMin)>, BuildKernelCreateInfo, + + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/tensor/gather_nd.cc b/onnxruntime/core/providers/cpu/tensor/gather_nd.cc index 9868cd3178..ae0e77e835 100644 --- a/onnxruntime/core/providers/cpu/tensor/gather_nd.cc +++ b/onnxruntime/core/providers/cpu/tensor/gather_nd.cc @@ -25,12 +25,22 @@ ONNX_OPERATOR_KERNEL_EX(GatherND, kMSDomain, 1, kCpuExecutionProvider, #endif -ONNX_CPU_OPERATOR_KERNEL(GatherND, 11, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) - // official ONNX spec only supports `int64_t` for indices - .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), - GatherND); +ONNX_CPU_OPERATOR_KERNEL( + GatherND, + 11, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) + // official ONNX spec only supports `int64_t` for indices + .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), + GatherND); + 
+ONNX_CPU_OPERATOR_KERNEL( + GatherND, + 12, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) + .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), + GatherND); template Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const { @@ -44,7 +54,7 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "indices tensor must has rank larger than 0"); } - int64_t last_indices_dimension = indices_shape[indices_shape.NumDimensions() - 1]; + int64_t last_indices_dimension = indices_shape[indices_shape.NumDimensions() - 1] + batch_dims_; if (last_indices_dimension > static_cast(input_shape.NumDimensions())) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "last dimension of indices must not be larger than rank of input tensor"); @@ -53,7 +63,7 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con std::vector shape(indices_shape.GetDims().begin(), indices_shape.GetDims().end() - 1); shape.insert(shape.end(), input_shape.GetDims().begin() + last_indices_dimension, input_shape.GetDims().end()); auto* output_tensor = context->Output(0, TensorShape(std::move(shape))); - std::vector element_counts(last_indices_dimension, + std::vector element_counts(last_indices_dimension + batch_dims_, 0LL); // Number of elements for each input dimension #ifdef _OPENMP @@ -63,12 +73,20 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con element_counts[i] = input_shape.SizeFromDimension(i + 1); } + auto last_dim_size = indices_shape.SizeFromDimension(indices_shape.NumDimensions() - 1); +#ifdef USE_OPENMP +#pragma omp parallel for +#endif + for (int64_t i = batch_dims_ - 1; i >= 0; --i) { + element_counts[last_indices_dimension + i] = indices_shape.SizeFromDimension(i + 1) / last_dim_size; + } + int64_t err_index = 0; p.element_bytes = input_tensor->DataType()->Size(); p.element_to_copy = 
input_shape.SizeFromDimension(last_indices_dimension); p.bytes_to_copy = p.element_bytes * p.element_to_copy; - const auto* indices_data = indices_tensor->Data(); - const int64_t offset_count = indices_shape.Size() / last_indices_dimension; // Times to copy + const auto* indice_offset = indices_tensor->Data(); + const int64_t offset_count = indices_shape.Size() / (last_indices_dimension - batch_dims_); // Times to copy p.element_offsets.assign(offset_count, 0LL); if (input_tensor->IsDataTypeString()) { @@ -79,12 +97,19 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con p.output_base = static_cast(output_tensor->MutableDataRaw()); } + //Compute the element_offset #ifdef _OPENMP #pragma omp parallel for #endif for (int64_t i = 0; i < offset_count; ++i) { - for (int64_t j = 0; j < last_indices_dimension; ++j) { - auto index = *(indices_data + i * last_indices_dimension + j); + int64_t reminder = i; + for (int64_t j = 0; j < batch_dims_; ++j) { + int64_t idx = reminder / element_counts[last_indices_dimension + j]; + p.element_offsets[i] += idx * element_counts[j]; + reminder -= (idx * element_counts[last_indices_dimension + j]); + } + for (int64_t j = batch_dims_; j < last_indices_dimension; ++j) { + auto index = *(indice_offset + i * (last_indices_dimension - batch_dims_) + (j - batch_dims_)); auto upper_limit = input_shape[j]; auto lower_limit = -upper_limit; if (index < lower_limit || index >= upper_limit) { diff --git a/onnxruntime/core/providers/cpu/tensor/gather_nd.h b/onnxruntime/core/providers/cpu/tensor/gather_nd.h index a169c5ab78..135706a245 100644 --- a/onnxruntime/core/providers/cpu/tensor/gather_nd.h +++ b/onnxruntime/core/providers/cpu/tensor/gather_nd.h @@ -33,11 +33,14 @@ class GatherNDBase { template Status PrepareForCompute(OpKernelContext* context, Prepare& p) const; + int64_t batch_dims_; }; // class GatherNDBase class GatherND final : public OpKernel, protected GatherNDBase { public: - explicit GatherND(const 
OpKernelInfo& info) : OpKernel(info) {} + explicit GatherND(const OpKernelInfo& info) : OpKernel(info) { + info.GetAttrOrDefault("batch_dims", &batch_dims_, static_cast(0)); + } Status Compute(OpKernelContext* context) const override; private: diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 88c101d429..1611b6c5b4 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -12,11 +12,11 @@ #include "core/providers/cuda/gpu_data_transfer.h" #ifndef DISABLE_CONTRIB_OPS -#include "contrib_ops/cuda_contrib_kernels.h" +#include "contrib_ops/cuda/cuda_contrib_kernels.h" #endif #ifdef ENABLE_TRAINING -#include "orttraining/training_ops/cuda_training_kernels.h" +#include "orttraining/training_ops/cuda/cuda_training_kernels.h" #endif using namespace onnxruntime::common; @@ -769,6 +769,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, int8_t, ReduceMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, uint8_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, int64_t, GatherND); + static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, @@ -1281,6 +1283,8 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc index 8fae7ae8b0..08ff82cee0 100644 --- a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc +++ 
b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc @@ -36,7 +36,10 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int e CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, streams_[exec_queue_id])); } else if (src_device.Type() == OrtDevice::GPU) { // copying between GPU, this is non-blocking - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, streams_[kCudaStreamDefault])); + // Copy only if the two addresses are different. + if (dst_data != src_data) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, streams_[kCudaStreamDefault])); + } } else { // copy from other CPU memory to GPU, this is blocking CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice)); diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd.cc b/onnxruntime/core/providers/cuda/tensor/gather_nd.cc similarity index 62% rename from orttraining/orttraining/training_ops/cuda/tensor/gather_nd.cc rename to onnxruntime/core/providers/cuda/tensor/gather_nd.cc index 62f23d39ff..53e2fcd153 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd.cc +++ b/onnxruntime/core/providers/cuda/tensor/gather_nd.cc @@ -1,14 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "orttraining/training_ops/cuda/tensor/gather_nd.h" -#include "orttraining/training_ops/cuda/tensor/gather_nd_impl.h" +#include "core/providers/cuda/tensor/gather_nd.h" +#include "core/providers/cuda/tensor/gather_nd_impl.h" #include "core/providers/cuda/shared_inc/cuda_utils.h" namespace onnxruntime { namespace cuda { -namespace { Status CheckBatchDimensionsMatch( size_t num_batch_dimensions, const std::vector>& tensor_shapes) { @@ -38,7 +37,6 @@ Status CheckBatchDimensionsMatch( return Status::OK(); } -} // namespace #define TYPED_FUNCTION_CALL_FWD(T) \ if (T_type == DataTypeImpl::GetType()) { \ @@ -50,9 +48,10 @@ Status CheckBatchDimensionsMatch( GatherNDGradImpl::MappedType>(num_slices, kernel_input_data, kernel_output_data, slice_size, input_slice_offsets_buffer.get()); \ } + template Status GatherNDBase::CommonComputeKernel( - const int64_t axis, + const int64_t batch_dims, const TensorShape& input_shape, const Tensor* kernel_input_tensor, Tensor* kernel_output_tensor, @@ -65,9 +64,9 @@ Status GatherNDBase::CommonComputeKernel( const auto num_slice_dims = indices_shape[indices_shape.NumDimensions() - 1]; const auto num_slices = indices_shape.SizeToDimension(indices_shape.NumDimensions() - 1); - const auto slice_size = input_shape.SizeFromDimension(axis + num_slice_dims); - const auto num_batches = input_shape.SizeToDimension(axis); - const auto input_batch_stride = input_shape.SizeFromDimension(axis); + const auto slice_size = input_shape.SizeFromDimension(batch_dims + num_slice_dims); + const auto num_batches = input_shape.SizeToDimension(batch_dims); + const auto input_batch_stride = input_shape.SizeFromDimension(batch_dims); const auto num_slices_per_batch = num_slices / num_batches; const TIndex* const indices_data = indices_tensor->Data(); @@ -79,7 +78,7 @@ Status GatherNDBase::CommonComputeKernel( auto running_product = slice_size; for (int64_t i = 0; i < num_slice_dims; ++i) { sizes_from_slice_dims[num_slice_dims - 1 - i] = running_product; - 
running_product *= input_shape[axis + num_slice_dims - 1 - i]; + running_product *= input_shape[batch_dims + num_slice_dims - 1 - i]; } } @@ -92,9 +91,11 @@ Status GatherNDBase::CommonComputeKernel( auto input_slice_offsets_buffer = GetScratchBuffer(num_slices); - // TODO error handling for invalid slice indices + TArray input_dims(input_shape.GetDims()); // TODO reuse computed slice offsets from GatherND in GatherNDGrad ComputeSliceOffsetsImpl( + batch_dims, + input_dims, num_slices, num_slices_per_batch, input_batch_stride, @@ -109,10 +110,15 @@ Status GatherNDBase::CommonComputeKernel( TYPED_FUNCTION_CALL_FWD(MLFloat16); TYPED_FUNCTION_CALL_FWD(double); } else { +#ifdef ENABLE_TRAINING MLDataType T_type = kernel_input_tensor->DataType(); TYPED_FUNCTION_CALL_BWD(float); TYPED_FUNCTION_CALL_BWD(MLFloat16); TYPED_FUNCTION_CALL_BWD(double); +#else + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Gradient computation is only supported in the training mode."); +#endif } return Status::OK(); @@ -121,11 +127,11 @@ Status GatherNDBase::CommonComputeKernel( #undef TYPED_FUNCTION_CALL_FWD #undef TYPED_FUNCTION_CALL_BWD -#define REGISTER_KERNEL_TYPED_GATHER_ND(TIndex) \ +#define REGISTER_KERNEL_TYPED_GATHER_ND(TIndex, ver) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ GatherND, \ kOnnxDomain, \ - 1, \ + ver, \ TIndex, \ kCudaExecutionProvider, \ KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ @@ -133,8 +139,11 @@ Status GatherNDBase::CommonComputeKernel( .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), \ GatherND); -REGISTER_KERNEL_TYPED_GATHER_ND(int64_t) -REGISTER_KERNEL_TYPED_GATHER_ND(int32_t) +// TODO: decprecate GatherND-1 after updating training models to opset-12 +#ifdef ENABLE_TRAINING +REGISTER_KERNEL_TYPED_GATHER_ND(int64_t, 1) +#endif +REGISTER_KERNEL_TYPED_GATHER_ND(int64_t, 12) template Status GatherND::ComputeInternal(OpKernelContext* context) const { @@ -151,14 +160,14 @@ Status GatherND::ComputeInternal(OpKernelContext* 
context) const { "indices tensor must has rank larger than 0"); } - auto last_indices_dimension = axis_ + indices_shape[indices_shape.NumDimensions() - 1]; + auto last_indices_dimension = batch_dims_ + indices_shape[indices_shape.NumDimensions() - 1]; if (last_indices_dimension > static_cast(input_shape.NumDimensions())) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "last dimension of indices must not be larger than rank of input tensor"); } ORT_RETURN_IF_ERROR(CheckBatchDimensionsMatch( - static_cast(axis_), {input_shape, indices_shape})); + static_cast(batch_dims_), {input_shape, indices_shape})); //Output shape std::vector shape(indices_shape.GetDims().begin(), indices_shape.GetDims().end() - 1); @@ -167,67 +176,10 @@ Status GatherND::ComputeInternal(OpKernelContext* context) const { auto output_tensor = context->Output(0, TensorShape(shape)); //Compute - auto status = CommonComputeKernel(axis_, input_shape, input_tensor, output_tensor, indices_shape, indices_tensor, true); + auto status = CommonComputeKernel(batch_dims_, input_shape, input_tensor, output_tensor, indices_shape, indices_tensor, true); return status; } -#define REGISTER_KERNEL_TYPED_GATHER_ND_GRAD(TIndex) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - GatherNDGrad, \ - kOnnxDomain, \ - 1, \ - TIndex, \ - kCudaExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ - DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) \ - .TypeConstraint("Tind", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(0), \ - GatherNDGrad); - -REGISTER_KERNEL_TYPED_GATHER_ND_GRAD(int64_t) -REGISTER_KERNEL_TYPED_GATHER_ND_GRAD(int32_t) - -template -Status GatherNDGrad::ComputeInternal(OpKernelContext* context) const { - auto shape_tensor = context->Input(0); - auto indices_tensor = context->Input(1); - auto update_tensor = context->Input(2); - ORT_RETURN_IF_NOT(shape_tensor != nullptr); - 
ORT_RETURN_IF_NOT(indices_tensor != nullptr); - ORT_RETURN_IF_NOT(update_tensor != nullptr); - - auto indices_shape = indices_tensor->Shape(); - auto update_shape = update_tensor->Shape(); - - if (indices_shape.NumDimensions() == 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "indices tensor must has rank larger than 0"); - } - - auto last_indices_dimension = axis_ + indices_shape[indices_shape.NumDimensions() - 1]; - - //Output - auto shape_data = shape_tensor->Data(); - auto input_shape = TensorShape(shape_data, shape_tensor->SizeInBytes() / sizeof(shape_tensor->DataType())); - - if (last_indices_dimension > static_cast(input_shape.NumDimensions())) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "last dimension of indices must not be larger than rank of input tensor"); - } - - ORT_RETURN_IF_ERROR(CheckBatchDimensionsMatch( - static_cast(axis_), {input_shape, indices_shape, update_shape})); - - auto output_tensor = context->Output(0, input_shape); - - // TODO this memset can be expensive, a sparse tensor representation would help here - CUDA_RETURN_IF_ERROR(cudaMemsetAsync(output_tensor->MutableDataRaw(), 0, output_tensor->SizeInBytes())); - - auto status = CommonComputeKernel(axis_, input_shape, update_tensor, output_tensor, indices_shape, indices_tensor, false); - return status; -} - } // namespace cuda } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd.h b/onnxruntime/core/providers/cuda/tensor/gather_nd.h similarity index 72% rename from orttraining/orttraining/training_ops/cuda/tensor/gather_nd.h rename to onnxruntime/core/providers/cuda/tensor/gather_nd.h index 70c429d6f9..9e5f5116a5 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd.h +++ b/onnxruntime/core/providers/cuda/tensor/gather_nd.h @@ -10,17 +10,21 @@ namespace onnxruntime { namespace cuda { +Status CheckBatchDimensionsMatch( + size_t num_batch_dimensions, + const std::vector>& tensor_shapes); + class 
GatherNDBase : public CudaKernel { public: GatherNDBase(const OpKernelInfo& info) : CudaKernel(info) { - info.GetAttrOrDefault("axis", &axis_, static_cast(0)); - ORT_ENFORCE(axis_ >= 0); + info.GetAttrOrDefault("batch_dims", &batch_dims_, static_cast(0)); + ORT_ENFORCE(batch_dims_ >= 0); } protected: template Status CommonComputeKernel( - const int64_t axis, + const int64_t batch_dims, const TensorShape& input_shape, const Tensor* input_tensor, Tensor* output_tensor, @@ -28,7 +32,7 @@ class GatherNDBase : public CudaKernel { const Tensor* indices_tensor, const bool fwd) const; - int64_t axis_; + int64_t batch_dims_; }; template @@ -38,12 +42,5 @@ class GatherND final : public GatherNDBase { Status ComputeInternal(OpKernelContext* context) const override; }; -template -class GatherNDGrad final : public GatherNDBase { - public: - GatherNDGrad(const OpKernelInfo& info) : GatherNDBase(info) {} - Status ComputeInternal(OpKernelContext* context) const override; -}; - } // namespace cuda } // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_impl.cu b/onnxruntime/core/providers/cuda/tensor/gather_nd_impl.cu similarity index 68% rename from orttraining/orttraining/training_ops/cuda/tensor/gather_nd_impl.cu rename to onnxruntime/core/providers/cuda/tensor/gather_nd_impl.cu index d754dcfe9f..91b0169e68 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gather_nd_impl.cu @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "orttraining/training_ops/cuda/tensor/gather_nd_impl.h" +#include "core/providers/cuda/tensor/gather_nd_impl.h" #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/atomic/common.cuh" @@ -11,6 +11,8 @@ namespace cuda { template __global__ void _ComputeSliceOffsetsKernel( + const int64_t batch_dims, + const TArray input_dims, const size_t num_slices, const size_t num_slices_per_batch, const size_t input_batch_stride, @@ -26,7 +28,12 @@ __global__ void _ComputeSliceOffsetsKernel( const TIndex* const slice_indices = indices_data + slice_idx * num_slice_dims; size_t relative_slice_offset = 0; for (size_t dim_idx = 0; dim_idx < num_slice_dims; ++dim_idx) { - relative_slice_offset += static_cast(slice_indices[dim_idx]) * sizes_from_slice_dims_data[dim_idx]; + int64_t index = static_cast(slice_indices[dim_idx]); + const size_t input_dim_idx = batch_dims + dim_idx; + CUDA_KERNEL_ASSERT(index >= -input_dims[input_dim_idx] && index < input_dims[input_dim_idx]); + if (index < 0) index += input_dims[input_dim_idx]; + + relative_slice_offset += index * sizes_from_slice_dims_data[dim_idx]; } input_slice_offsets_data[slice_idx] = base_offset + relative_slice_offset; @@ -44,21 +51,10 @@ __global__ void _GatherNDKernel( output_data[i] = input_data[slice_offset + i % slice_size]; }; -template -__global__ void _GatherNDGradKernel( - const size_t num_slices, - const T* update_data, - T* output_data, - const size_t slice_size, - const int64_t* slice_offsets) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(i, num_slices * slice_size); - uint64_t slice_offset = slice_offsets[i / slice_size]; - size_t j = i % slice_size; - atomic_add(output_data + slice_offset + j, update_data[i]); -}; - template void ComputeSliceOffsetsImpl( + const int64_t batch_dims, + const TArray input_dims, const size_t num_slices, const size_t num_slices_per_batch, const size_t input_batch_stride, @@ -68,6 +64,8 @@ void ComputeSliceOffsetsImpl( int64_t* const input_slice_offsets_data) { 
// num_slices elements const auto blocks_per_grid = CeilDiv(num_slices, GridDim::maxThreadsPerBlock); _ComputeSliceOffsetsKernel<<>>( + batch_dims, + input_dims, num_slices, num_slices_per_batch, input_batch_stride, @@ -89,45 +87,28 @@ void GatherNDImpl( num_slices, static_cast(input_data), static_cast(output_data), slice_size, input_slice_offsets_data); } -template -void GatherNDGradImpl( - const size_t num_slices, - const void* update_data, - void* output_data, - const size_t slice_size, - const int64_t* input_slice_offsets_data) { - const auto blocks_per_grid = CeilDiv(num_slices * slice_size, GridDim::maxThreadsPerBlock); - _GatherNDGradKernel<<>>( - num_slices, static_cast(update_data), static_cast(output_data), slice_size, input_slice_offsets_data); -} - #define SPECIALIZED_COMPUTE_SLICE_OFFSETS_IMPL(TIndex) \ template void ComputeSliceOffsetsImpl( \ + const int64_t batch_dims, \ + const TArray input_dims, \ const size_t num_slices, \ const size_t num_slices_per_batch, \ const size_t input_batch_stride, \ const size_t num_slice_dims, \ const int64_t* const sizes_from_slice_dims_data, \ const TIndex* const indices_data, \ - int64_t* const input_slice_offsets_data) + int64_t* const input_slice_offsets_data); #define SPECIALIZED_IMPL(T) \ - template void GatherNDImpl(const size_t num_slices, const void* input_data, void* output_data, const size_t slice_size, const int64_t* input_slice_offsets_data) + template void GatherNDImpl(const size_t num_slices, const void* input_data, void* output_data, const size_t slice_size, const int64_t* input_slice_offsets_data); -#define SPECIALIZED_GRAD_IMPL(T) \ - template void GatherNDGradImpl(const size_t num_slices, const void* update_data, void* output_data, const size_t slice_size, const int64_t* input_slice_offsets_data) +SPECIALIZED_COMPUTE_SLICE_OFFSETS_IMPL(int32_t) +SPECIALIZED_COMPUTE_SLICE_OFFSETS_IMPL(int64_t) -SPECIALIZED_COMPUTE_SLICE_OFFSETS_IMPL(int32_t); -SPECIALIZED_COMPUTE_SLICE_OFFSETS_IMPL(int64_t); - 
-SPECIALIZED_IMPL(float); -SPECIALIZED_GRAD_IMPL(float); +SPECIALIZED_IMPL(float) #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 -SPECIALIZED_IMPL(half); -SPECIALIZED_GRAD_IMPL(half); - -SPECIALIZED_IMPL(double); -SPECIALIZED_GRAD_IMPL(double); +SPECIALIZED_IMPL(half) +SPECIALIZED_IMPL(double) #endif } // namespace cuda diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_impl.h b/onnxruntime/core/providers/cuda/tensor/gather_nd_impl.h similarity index 91% rename from orttraining/orttraining/training_ops/cuda/tensor/gather_nd_impl.h rename to onnxruntime/core/providers/cuda/tensor/gather_nd_impl.h index aa315d9fa3..e989fb330a 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/gather_nd_impl.h @@ -9,6 +9,8 @@ namespace cuda { template void ComputeSliceOffsetsImpl( + const int64_t batch_dims, + const TArray input_dims, const size_t num_slices, const size_t num_slices_per_batch, const size_t input_batch_stride, @@ -25,6 +27,7 @@ void GatherNDImpl( const size_t slice_size, const int64_t* input_slice_offsets_data); +#ifdef ENABLE_TRAINING template void GatherNDGradImpl( const size_t num_slices, @@ -32,6 +35,7 @@ void GatherNDGradImpl( void* output_data, const size_t slice_size, const int64_t* input_slice_offsets_data); +#endif } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/featurizers_ops/cpu_featurizers_kernels.cc b/onnxruntime/featurizers_ops/cpu/cpu_featurizers_kernels.cc similarity index 99% rename from onnxruntime/featurizers_ops/cpu_featurizers_kernels.cc rename to onnxruntime/featurizers_ops/cpu/cpu_featurizers_kernels.cc index 854dcc9063..4a7c10dd1f 100644 --- a/onnxruntime/featurizers_ops/cpu_featurizers_kernels.cc +++ b/onnxruntime/featurizers_ops/cpu/cpu_featurizers_kernels.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "featurizers_ops/cpu_featurizers_kernels.h" +#include "featurizers_ops/cpu/cpu_featurizers_kernels.h" #include "core/graph/constants.h" #include "core/framework/data_types.h" diff --git a/onnxruntime/featurizers_ops/cpu_featurizers_kernels.h b/onnxruntime/featurizers_ops/cpu/cpu_featurizers_kernels.h similarity index 100% rename from onnxruntime/featurizers_ops/cpu_featurizers_kernels.h rename to onnxruntime/featurizers_ops/cpu/cpu_featurizers_kernels.h diff --git a/onnxruntime/test/common/tensor_op_test_utils.h b/onnxruntime/test/common/tensor_op_test_utils.h index 1122e3cb56..5259d9c5f2 100644 --- a/onnxruntime/test/common/tensor_op_test_utils.h +++ b/onnxruntime/test/common/tensor_op_test_utils.h @@ -55,6 +55,20 @@ inline std::vector FillZeros(const std::vector& dims) { return val; } +// Returns a vector of `count` values which start at `start` and change by increments of `step`. +template +inline std::vector ValueRange( + size_t count, T start = static_cast(0), T step = static_cast(1)) { + std::vector result; + result.reserve(count); + T curr = start; + for (size_t i = 0; i < count; ++i) { + result.emplace_back(curr); + curr += step; + } + return result; +} + inline std::pair MeanStdev(std::vector& v) { float sum = std::accumulate(v.begin(), v.end(), 0.0f); float mean = sum / v.size(); diff --git a/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc b/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc index 91164a608f..7b15687120 100644 --- a/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc @@ -3,6 +3,8 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/common/tensor_op_test_utils.h" namespace onnxruntime { namespace test { @@ -18,15 +20,22 @@ static void RunTest(const std::vector& input_dims, const std::initializ test1.AddOutput("output", output_dims, output); test1.Run(); 
-#ifndef DISABLE_CONTRIB_OPS - - // MSFT domain opset-1 (contrib op) - OpTester test2("GatherND", 1, kMSDomain); + // ONNX domain opset-12 + OpTester test2("GatherND", 12); test2.AddInput("data", input_dims, input); test2.AddInput("indices", indices_dims, indices); test2.AddOutput("output", output_dims, output); test2.Run(); +#ifndef DISABLE_CONTRIB_OPS + + // MSFT domain opset-1 (contrib op) + OpTester test3("GatherND", 1, kMSDomain); + test3.AddInput("data", input_dims, input); + test3.AddInput("indices", indices_dims, indices); + test3.AddOutput("output", output_dims, output); + test3.Run(); + #endif } @@ -70,11 +79,21 @@ TEST(GatherNDOpTest, int64_t) { } TEST(GatherNDOpTest, float) { + if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; + RunTest({2, 2}, {0.0f, 0.1f, 0.2f, 0.3f}, {2, 1}, {1LL, 0LL}, {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f}); + + // with negative indices + RunTest({2, 2}, {0.0f, 0.1f, 0.2f, 0.3f}, {2, 1}, {-1LL, 0LL}, {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f}); } TEST(GatherNDOpTest, double) { + if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; + RunTest({2, 2}, {0.0, 0.1, 0.2, 0.3}, {2, 1}, {1LL, 0LL}, {2, 2}, {0.2, 0.3, 0.0, 0.1}); + + // with negative indices + RunTest({2, 2}, {0.0, 0.1, 0.2, 0.3}, {2, 1}, {-1LL, 0LL}, {2, 2}, {0.2, 0.3, 0.0, 0.1}); } TEST(GatherNDOpTest, int8_t) { @@ -114,5 +133,116 @@ TEST(GatherNDOpTest, ContribOpInt32Indices) { #endif +TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_0) { + OpTester test("GatherND", 12, kOnnxDomain); + test.AddAttribute("batch_dims", 0); + test.AddInput("data", {2, 3, 4}, ValueRange(24, 1.0f)); + test.AddInput("indices", {3, 2, 2}, {0LL, 1LL, 0LL, 2LL, 1LL, 0LL, 0LL, 0LL, 1LL, 1LL, 1LL, 2LL}); + test.AddOutput("output", {3, 2, 4}, {5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 1.0, 2.0, 3.0, 4.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_1) { + OpTester test("GatherND", 12, 
kOnnxDomain); + test.AddAttribute("batch_dims", 1); + test.AddInput("data", {2, 3, 4}, ValueRange(24, 1.0f)); + test.AddInput("indices", {2, 2, 2}, {0LL, 1LL, 0LL, 2LL, 1LL, 0LL, 0LL, 0LL}); + test.AddOutput("output", {2, 2}, {2.0, 3.0, 17.0, 13.0}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_2) { + OpTester test("GatherND", 12, kOnnxDomain); + test.AddAttribute("batch_dims", 1); + test.AddInput("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f)); + test.AddInput("indices", {2, 1}, {1LL, 0LL}); + test.AddOutput("output", {2, 2}, {0.2f, 0.3f, 0.4f, 0.5f}); + test.Run(); +} + +#ifdef USE_CUDA +#if __CUDA_ARCH__ >= 600 +TEST(GatherNDOpTest, GatherND_slice_double_batch_dims_3) { + OpTester test("GatherND", 12, kOnnxDomain); + test.AddAttribute("batch_dims", 1); + test.AddInput("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f)); + test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); + test.AddOutput("output", {2, 1, 2}, {0.2f, 0.3f, 0.4f, 0.5f}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_slice_double) { + OpTester test("GatherND", 12, kOnnxDomain); + test.AddInput("data", {2, 2}, {0.0f, 0.1f, 0.2f, 0.3f}); + test.AddInput("indices", {2, 1}, {1LL, 0LL}); + test.AddOutput("output", {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f}); + test.Run(); +} +#endif +#endif + +TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_4) { + OpTester test("GatherND", 12, kOnnxDomain); + test.AddAttribute("batch_dims", 1); + test.AddInput("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f)); + test.AddInput("indices", {2, 1, 2}, {1LL, 0LL, 0LL, 1LL}); + test.AddOutput("output", {2, 1}, {0.2f, 0.5f}); + test.Run(); +} + +#ifdef USE_CUDA + +TEST(GatherNDOpTest, GatherND_slice_double_batch_dims_3) { + if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; + + OpTester test("GatherND", 12, kOnnxDomain); + test.AddAttribute("batch_dims", 1); + test.AddInput("data", {2, 2, 2}, ValueRange(8, 0.0, 0.1)); + test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); + test.AddOutput("output", {2, 1, 
2}, {0.2f, 0.3f, 0.4f, 0.5f}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_slice_half) { + if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; + + OpTester test("GatherND", 12, kOnnxDomain); + std::vector data_f({0.0f, 0.1f, 0.2f, 0.3f}); + std::vector outputs_f({0.2f, 0.3f, 0.0f, 0.1f}); + std::vector data(4); + std::vector outputs(4); + ConvertFloatToMLFloat16(data_f.data(), data.data(), 4); + ConvertFloatToMLFloat16(outputs_f.data(), outputs.data(), 4); + test.AddInput("data", {2, 2}, data); + test.AddInput("indices", {2, 1}, {1LL, 0LL}); + test.AddOutput("output", {2, 2}, outputs); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_batch_dims_of_2) { + OpTester test("GatherND", 12, kOnnxDomain); + test.AddAttribute("batch_dims", 2); + test.AddInput("data", {2, 2, 2, 2, 3}, ValueRange(48)); + test.AddInput( + "indices", {2, 2, 1, 2}, + { + 0, 0, // batch 0 + 1, 0, // batch 1 + 1, 1, // batch 2 + 0, 1, // batch 3 + }); + test.AddOutput( + "output", {2, 2, 1, 3}, + { + 0, 1, 2, // batch 0 + 18, 19, 20, // batch 1 + 33, 34, 35, // batch 2 + 39, 40, 41, // batch 3 + }); + test.Run(); +} + +#endif + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 609ff30304..0c19e809ca 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -438,16 +438,16 @@ IMPLEMENT_GRADIENT_BUILDER(GetConcatGradient) { IMPLEMENT_GRADIENT_BUILDER(GetGatherNDGradient) { auto attributes = SrcNodeAttributes(); - ORT_ENFORCE(attributes.at("axis").has_i()); - auto axis = attributes.at("axis").i(); + ORT_ENFORCE(attributes.at("batch_dims").has_i()); + auto batch_dims = attributes.at("batch_dims").i(); return std::vector{ NodeDef("Shape", {I(0)}, {IA("x_shape")}), - NodeDef("GatherNDGrad", + NodeDef(OpDef{"GatherNDGrad", kMSDomain, 1}, {IA("x_shape"), I(1), GO(0)}, {GI(0)}, - 
{MakeAttribute("axis", axis)})}; + {MakeAttribute("batch_dims", batch_dims)})}; }; IMPLEMENT_GRADIENT_BUILDER(GetReshapeGradient) { diff --git a/orttraining/orttraining/core/graph/gradient_schema_defs.cc b/orttraining/orttraining/core/graph/gradient_schema_defs.cc index 9249dd8410..8657a27f63 100644 --- a/orttraining/orttraining/core/graph/gradient_schema_defs.cc +++ b/orttraining/orttraining/core/graph/gradient_schema_defs.cc @@ -698,12 +698,84 @@ void RegisterGradientSchemas() { propagateShapeAndTypeFromFirstInput(ctx); }); - ONNX_CONTRIB_OPERATOR_SCHEMA(GatherNDGrad) + // TODO: Depreacate this schema when training support is udpated to opset-12 + ONNX_CONTRIB_OPERATOR_SCHEMA(GatherND) .SetDomain(kOnnxDomain) .SinceVersion(1) .Attr( - "axis", - "The number of batch dims. The gather of indexing starts from dimension of data[axis+1:]", + "batch_dims", + "The number of batch dims. The gather of indexing starts from dimension of data[batch_dims:]", + AttributeProto::INT, + static_cast(0)) + .Input(0, "data", "Tensor of rank r >= 1.", "T") + .Input(1, "indices", "Tensor of rank q >= 1.", "Tind") + .Output(0, "output", "Tensor of rank q-1+r-indices[-1].", "T") + .TypeConstraint( + "T", + OpSchema::all_tensor_types(), + "Constrain input and output types to any tensor type.") + .TypeConstraint( + "Tind", + {"tensor(int32)", "tensor(int64)"}, + "Constrain indice type to int32 or int64") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (!hasNInputShapes(ctx, 2)) { + return; + } + auto& data_shape = ctx.getInputType(0)->tensor_type().shape(); + auto& indices_shape = ctx.getInputType(1)->tensor_type().shape(); + auto data_rank = data_shape.dim_size(); + auto indices_rank = indices_shape.dim_size(); + auto batch_dims = ctx.getAttribute("batch_dims"); + int64_t batch_dims_data = batch_dims ? 
static_cast(batch_dims->i()) : 0; + if (data_rank < 1 || indices_rank < 1) { + fail_shape_inference("both data and indices tensor need to have rank larger than zero."); + } + auto last_indice_dimension = indices_shape.dim(indices_rank - 1).dim_value() + batch_dims_data; + if (last_indice_dimension > data_rank) { + fail_shape_inference("last dimension of indices must not be larger and rank of data tensor"); + } + for (int i = 0; i < indices_rank - 1; ++i) { + *ctx.getOutputType(0) + ->mutable_tensor_type() + ->mutable_shape() + ->add_dim() = indices_shape.dim(i); + } + for (int i = static_cast(last_indice_dimension); i < data_rank; ++i) { + *ctx.getOutputType(0) + ->mutable_tensor_type() + ->mutable_shape() + ->add_dim() = data_shape.dim(i); + } + }) + .SetDoc(R"DOC( +Given `data` tensor of rank r >= 1, and `indices` tensor of rank q >= 1, gather +slices of `data` into an output tensor of rank q - 1 + r - indices[-1]. +Example 1: + data = [[0,1],[2,3]] + indices = [[0,0],[1,1]] + output = [0,3] +Example 2: + data = [[0,1],[2,3]] + indices = [[1],[0]] + output = [[2,3],[0,1]] +Example 3: + data = [[[0,1],[2,3]],[[4,5],[6,7]]] + indices = [[0,1],[1,0]] + output = [[2,3],[4,5]] +Example 4: + data = [[[0,1],[2,3]],[[4,5],[6,7]]] + indices = [[[0,1]],[[1,0]]] + output = [[[2,3]],[[4,5]]] +)DOC"); + + ONNX_CONTRIB_OPERATOR_SCHEMA(GatherNDGrad) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr( + "batch_dims", + "The number of batch dims. 
The gather of indexing starts from dimension of data[batch_dims+1:]", AttributeProto::INT, static_cast(0)) .Input(0, "shape", "The shape of source data input of GatherND.", "T1") @@ -716,7 +788,7 @@ void RegisterGradientSchemas() { "Constrain input and output types to any tensor type.") .TypeConstraint( "Tind", - {"tensor(int32)", "tensor(int64)"}, + {"tensor(int64)"}, "Constrain indice type to int32 or int64") .TypeConstraint( "T1", diff --git a/orttraining/orttraining/core/graph/loss_func/bert_loss.cc b/orttraining/orttraining/core/graph/loss_func/bert_loss.cc index a75f9faa5a..b3237c2405 100644 --- a/orttraining/orttraining/core/graph/loss_func/bert_loss.cc +++ b/orttraining/orttraining/core/graph/loss_func/bert_loss.cc @@ -83,10 +83,10 @@ GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFun "Mask_LM_Positions_Unsqueezed")); TypeProto* gathered_prediction_type_proto = GetGatheredPredictionTypeProto(prediction_arg, graph_defs); - new_nodes.emplace_back(NodeDef("GatherND", + new_nodes.emplace_back(NodeDef(OpDef{"GatherND", kOnnxDomain, 12}, {ArgDef(prediction_masked_lm), ArgDef("masked_lm_positions_unsqueezed")}, {ArgDef("gathered_prediction", gathered_prediction_type_proto)}, - {ONNX_NAMESPACE::MakeAttribute("axis", static_cast(1))}, + {ONNX_NAMESPACE::MakeAttribute("batch_dims", static_cast(1))}, "GATHERED_LM")); TypeProto* masked_lm_float_type_proto = GetMaskedLMTypeProto(prediction_arg, diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 8e68e15d5c..4628bad416 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -94,7 +94,7 @@ TrainingConfigurationResult ConfigureSessionForTraining( config.weight_names_to_not_train = parameters.weights_not_to_train; config.immutable_weights = parameters.immutable_weights; - config.set_gradients_as_graph_outputs = true; + 
config.set_gradients_as_graph_outputs = false; config.gradient_accumulation_steps = parameters.gradient_accumulation_steps; @@ -115,6 +115,7 @@ TrainingConfigurationResult ConfigureSessionForTraining( config.loss_name = parameters.loss_output_name; if (!parameters.training_optimizer_name.empty()) { + config.set_gradients_as_graph_outputs = true; training::TrainingSession::TrainingConfiguration::OptimizerConfiguration opt{}; opt.name = parameters.training_optimizer_name; opt.learning_rate_input_name = parameters.lr_params_feed_name; @@ -276,4 +277,4 @@ void addObjectMethodsForTraining(py::module& m) { } } // namespace python -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/orttraining/orttraining/test/gradient/event_op_test.cc b/orttraining/orttraining/test/gradient/event_op_test.cc new file mode 100644 index 0000000000..881ed4fa5d --- /dev/null +++ b/orttraining/orttraining/test/gradient/event_op_test.cc @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "test/common/tensor_op_test_utils.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/test_random_seed.h" +#include "test/util/include/default_providers.h" + +#include "onnx/defs/attr_proto_util.h" + +namespace onnxruntime { +namespace test { + +// Run GPU op for GPU build. Otherwise, run CPU op.
+void run_provider_specific_optest(OpTester& tester) { + RunOptions run_option; +#ifdef USE_CUDA + std::vector> providers; + providers.push_back(DefaultCudaExecutionProvider()); +#else + std::vector> providers; + providers.push_back(DefaultCpuExecutionProvider()); +#endif + tester.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + std::unordered_set(), + &run_option, + &providers); +} + +void record_event(int64_t event_id) { + OpTester test_record("RecordEvent", 1, onnxruntime::kMSDomain); + test_record.AddInput("EventIdentifier", {}, {event_id}); + test_record.AddInput("InputSignal", {}, {true}); + test_record.AddOutput("OutputSignal", {}, {true}); + run_provider_specific_optest(test_record); +} + +void record_event_multiple_inputs_and_outputs(int64_t event_id) { + OpTester test_record("RecordEvent", 1, onnxruntime::kMSDomain); + test_record.AddInput("EventIdentifier", {}, {event_id}); + test_record.AddInput("InputSignal", {}, {true}); + test_record.AddInput("Input1", {3}, {9.4f, 1.7f, 3.6f}); + test_record.AddInput("Input2", {1}, {1.6f}); + test_record.AddOutput("OutputSignal", {}, {true}); + test_record.AddOutput("Output1", {3}, {9.4f, 1.7f, 3.6f}); + test_record.AddOutput("Output2", {1}, {1.6f}); + run_provider_specific_optest(test_record); +} + +void wait_event(int64_t event_id) { + OpTester test_wait("WaitEvent", 1, onnxruntime::kMSDomain); + test_wait.AddInput("EventIdentifier", {}, {event_id}); + test_wait.AddInput("InputSignal", {}, {true}); + test_wait.AddOutput("OutputSignal", {}, {true}); + run_provider_specific_optest(test_wait); +} + +void wait_event_multiple_inputs_and_outputs(int64_t event_id) { + OpTester test_wait("WaitEvent", 1, onnxruntime::kMSDomain); + test_wait.AddInput("EventIdentifier", {}, {event_id}); + test_wait.AddInput("InputSignal", {}, {true}); + test_wait.AddInput("Input1", {1}, {1.6f}); + test_wait.AddInput("Input2", {3}, {9.4f, 1.7f, 3.6f}); + test_wait.AddOutput("OutputSignal", {}, {true}); + test_wait.AddOutput("output1", 
{1}, {1.6f}); + test_wait.AddOutput("output2", {3}, {9.4f, 1.7f, 3.6f}); + run_provider_specific_optest(test_wait); +} + +TEST(Synchronization, RecordAndWaitEvent) { + const int64_t event_id = static_cast(1736); + record_event(event_id); + wait_event(event_id); +} + +TEST(Synchronization, WaitNullEvent) { + wait_event(-1); +} + +TEST(Synchronization, RecordAndWaitEventMultipleInputsAndOutputs) { + const int64_t event_id = static_cast(995); + record_event_multiple_inputs_and_outputs(event_id); + wait_event_multiple_inputs_and_outputs(event_id); +} + +TEST(Synchronization, WaitAndRecordEvent) { + const int64_t event_id = static_cast(1228); + std::thread waiting_thread(wait_event, event_id); + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + std::thread recording_thread(record_event, event_id); + + waiting_thread.join(); + recording_thread.join(); +} + +TEST(Synchronization, WaitAndRecordEventMany) { + const size_t event_count = 16; + for (int i = 0; i < 8; ++i) { + std::thread thread_pool[2 * event_count]; + for (int j = 0; j < static_cast(event_count); ++j) { + thread_pool[j] = std::thread(wait_event, j); + thread_pool[j + event_count] = std::thread(record_event, j); + } + for (size_t j = 0; j < event_count; ++j) { + thread_pool[j].join(); + thread_pool[j + event_count].join(); + } + } +} + +} // namespace test +} // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/test/gradient/gradient_checker.cc b/orttraining/orttraining/test/gradient/gradient_checker.cc index c12fae5cd1..4fc263a461 100644 --- a/orttraining/orttraining/test/gradient/gradient_checker.cc +++ b/orttraining/orttraining/test/gradient/gradient_checker.cc @@ -457,8 +457,8 @@ inline Status GradientChecker::ComputeGradientErrorInternal( // TODO: These 4 test failed at following ORT_ENFORCE. need investigate before enable it. 
//GradientCheckerTest.MatMulGrad //GradientCheckerTest.GemmGrad - //GradientCheckerTest.GatherNDGrad_int64_indice_repeat_float_data - //GradientCheckerTest.GatherNDGrad_int64_indice_unique_float_data + //GradientCheckerTest.GatherNDGrad_repeat_float_data + //GradientCheckerTest.GatherNDGrad_unique_float_data //auto jac_t = jacobian_ts[j]; //ORT_ENFORCE(std::all_of( // &jac_t[0], &jac_t[0] + x_info.shape.Size(), [](auto dx) { return dx == 0; })); diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 556160bd58..a407d07e7d 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -1535,68 +1535,62 @@ TEST(GradientCheckerTest, DISABLED_DropoutGrad) { } } -TEST(GradientCheckerTest, GatherNDGrad_int64_indice_repeat_float_data) { +TEST(GradientCheckerTest, GatherNDGrad_repeat_float_data) { float max_error; GradientChecker gradient_checker; - OpDef op_def{"GatherND"}; + OpDef op_def{"GatherND", kOnnxDomain, 12}; TensorInfo x_info({2, 2}, true); TensorInfo indice_info({2, 2}, false, nullptr, DataTypeImpl::GetTensorType()); std::vector> x_datas = {{0, 1, 2, 3}, {1, 1, 1, 1}}; TensorInfo y_info({2}, true); - int64_t axis = 0; + int64_t batch_dims = 0; - gradient_checker.ComputeGradientError(op_def, {x_info, indice_info}, {y_info}, &max_error, x_datas, {MakeAttribute("axis", axis)}); + gradient_checker.ComputeGradientError(op_def, {x_info, indice_info}, {y_info}, &max_error, x_datas, {MakeAttribute("batch_dims", batch_dims)}); EXPECT_IS_TINY(max_error); } -TEST(GradientCheckerTest, GatherNDGrad_int64_indice_unique_float_data) { +TEST(GradientCheckerTest, GatherNDGrad_unique_float_data) { float max_error; GradientChecker gradient_checker; - OpDef op_def{"GatherND"}; + OpDef op_def{"GatherND", kOnnxDomain, 12}; - TensorInfo x_info({2, 2}, true); - TensorInfo indice_info({2, 2}, false, nullptr, 
DataTypeImpl::GetTensorType()); - std::vector> x_datas = {{0, 1, 2, 3}, {0, 1, 1, 0}}; + { + TensorInfo x_info({2, 2}, true); + TensorInfo indice_info({2, 2}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{0, 1, 2, 3}, {0, 1, 1, 0}}; - TensorInfo y_info({2}, true); - int64_t axis = 0; + TensorInfo y_info({2}, true); + int64_t batch_dims = 0; - gradient_checker.ComputeGradientError(op_def, {x_info, indice_info}, {y_info}, &max_error, x_datas, {MakeAttribute("axis", axis)}); - EXPECT_IS_TINY(max_error); -} + gradient_checker.ComputeGradientError(op_def, {x_info, indice_info}, {y_info}, &max_error, x_datas, {MakeAttribute("batch_dims", batch_dims)}); + EXPECT_IS_TINY(max_error); + } -TEST(GradientCheckerTest, GatherNDGrad_int32_indice_unique_float_data) { - float max_error; - GradientChecker gradient_checker; - OpDef op_def{"GatherND"}; + { + TensorInfo x_info({2, 2, 3}, true); + TensorInfo indice_info({2, 1}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1, 0}}; - TensorInfo x_info({2, 2, 3}, true); - TensorInfo indice_info({2, 1}, false, nullptr, DataTypeImpl::GetTensorType()); - std::vector> x_datas = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1, 0}}; + TensorInfo y_info({2, 3}, true); + int64_t batch_dims = 1; - TensorInfo y_info({2, 3}, true); - int64_t axis = 1; + gradient_checker.ComputeGradientError(op_def, {x_info, indice_info}, {y_info}, &max_error, x_datas, {MakeAttribute("batch_dims", batch_dims)}); + EXPECT_IS_TINY(max_error); + } - gradient_checker.ComputeGradientError(op_def, {x_info, indice_info}, {y_info}, &max_error, x_datas, {MakeAttribute("axis", axis)}); - EXPECT_IS_TINY(max_error); -} + { + TensorInfo x_info({2, 2, 3}, true); + TensorInfo indice_info({2, 2, 1}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1, 0, 2, 1}}; -TEST(GradientCheckerTest, 
GatherNDGrad_int32_indice_unique_float_data_axis_2) { - float max_error; - GradientChecker gradient_checker; - OpDef op_def{"GatherND"}; + TensorInfo y_info({2, 2}, true); + int64_t batch_dims = 2; - TensorInfo x_info({2, 2, 3}, true); - TensorInfo indice_info({2, 2, 1}, false, nullptr, DataTypeImpl::GetTensorType()); - std::vector> x_datas = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1, 0, 2, 1}}; - - TensorInfo y_info({2, 2}, true); - int64_t axis = 2; - - gradient_checker.ComputeGradientError(op_def, {x_info, indice_info}, {y_info}, &max_error, x_datas, {MakeAttribute("axis", axis)}); - EXPECT_IS_TINY(max_error); + gradient_checker.ComputeGradientError(op_def, {x_info, indice_info}, {y_info}, &max_error, x_datas, {MakeAttribute("batch_dims", batch_dims)}); + EXPECT_IS_TINY(max_error); + } } TEST(GradientCheckerTest, GatherElementsGradWithDuplicateUpdate) { @@ -1804,53 +1798,6 @@ TEST(GradientCheckerTest, SliceGrad) { } } -void record_event(int64_t event_id) { - OpTester test_record("RecordEvent", 1, onnxruntime::kMSDomain); - test_record.AddInput("EventIdentifier", {}, {event_id}); - test_record.AddInput("InputSignal", {}, {true}); - test_record.AddOutput("OutputSignal", {}, {true}); - test_record.Run(); -} - -void wait_event(int64_t event_id) { - OpTester test_wait("WaitEvent", 1, onnxruntime::kMSDomain); - test_wait.AddInput("EventIdentifier", {}, {event_id}); - test_wait.AddInput("InputSignal", {}, {true}); - test_wait.AddOutput("OutputSignal", {}, {true}); - test_wait.Run(); -} - -TEST(Synchronization, RecordAndWaitEvent) { - const int64_t event_id = static_cast(1736); - record_event(event_id); - wait_event(event_id); -} - -TEST(Synchronization, WaitAndRecordEvent) { - const int64_t event_id = static_cast(1228); - std::thread waiting_thread(wait_event, event_id); - std::this_thread::sleep_for(std::chrono::milliseconds(5)); - std::thread recording_thread(record_event, event_id); - - waiting_thread.join(); - recording_thread.join(); -} - -TEST(Synchronization, 
WaitAndRecordEventMany) { - const size_t event_count = 16; - for (int i = 0; i < 8; ++i) { - std::thread thread_pool[2 * event_count]; - for (int j = 0; j < static_cast(event_count); ++j) { - thread_pool[j] = std::thread(wait_event, j); - thread_pool[j + event_count] = std::thread(record_event, j); - } - for (size_t j = 0; j < event_count; ++j) { - thread_pool[j].join(); - thread_pool[j + event_count].join(); - } - } -} - TEST(GradientCheckerTest, ExpandGrad) { float max_error; GradientChecker gradient_checker; diff --git a/orttraining/orttraining/test/training_ops/cpu/tensor/gather_nd_grad_op_test.cc b/orttraining/orttraining/test/training_ops/cpu/tensor/gather_nd_grad_op_test.cc new file mode 100644 index 0000000000..02c32a0827 --- /dev/null +++ b/orttraining/orttraining/test/training_ops/cpu/tensor/gather_nd_grad_op_test.cc @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/common/tensor_op_test_utils.h" + +namespace onnxruntime { +namespace test { + +#ifdef USE_CUDA +TEST(GatherNDGradOpTest, GatherNDGrad_slice_float_int64_t_batch_dims_1) { + OpTester test("GatherNDGrad", 1, kMSDomain); + test.AddAttribute("batch_dims", 0); + test.AddInput("shape", {3}, {2LL, 2LL, 3LL}); + test.AddInput("indices", {2, 2}, {0LL, 1LL, 1LL, 0LL}); + test.AddInput("update", {2, 3}, ValueRange(6, 1.0f)); + test.AddOutput("output", {2, 2, 3}, {0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0}); + test.Run(); +} + +TEST(GatherNDGradOpTest, GatherNDGrad_slice_double_int32_t_batch_dims_3) { + if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; + + OpTester test("GatherNDGrad", 1, kMSDomain); + test.AddAttribute("batch_dims", 1); + test.AddInput("shape", {3}, {2LL, 2LL, 3LL}); + test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); + test.AddInput("update", {2, 3}, ValueRange(6, 1.0)); + 
test.AddOutput("output", {2, 2, 3}, {0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0}); + test.Run(); +} + + +TEST(GatherNDGradOpTest, GatherNDGrad_slice_half_int32_t_batch_dims_3) { + if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; + + OpTester test("GatherNDGrad", 1, kMSDomain); + test.AddAttribute("batch_dims", 1); + test.AddInput("shape", {3}, {2LL, 2LL, 3LL}); + test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); + std::vector updates_f = ValueRange(6, 1.0f); + std::vector outputs_f({0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0}); + std::vector updates(6); + std::vector outputs(12); + ConvertFloatToMLFloat16(updates_f.data(), updates.data(), 6); + ConvertFloatToMLFloat16(outputs_f.data(), outputs.data(), 12); + test.AddInput("update", {2, 3}, updates); + test.AddOutput("output", {2, 2, 3}, outputs); + test.Run(); +} + +TEST(GatherNDGradOpTest, GatherNDGrad_batch_dims_of_2) { + OpTester test("GatherNDGrad", 1, kMSDomain); + test.AddAttribute("batch_dims", 2); + test.AddInput("shape", {4}, {2, 2, 2, 3}); + test.AddInput( + "indices", {2, 2, 1}, + { + 1, // batch 0 + 1, // batch 1 + 0, // batch 2 + 1, // batch 3 + }); + test.AddInput("update", {2, 2, 3}, ValueRange(12)); + test.AddOutput( + "output", {2, 2, 2, 3}, + { + 0, 0, 0, 0, 1, 2, // batch 0 + 0, 0, 0, 3, 4, 5, // batch 1 + 6, 7, 8, 0, 0, 0, // batch 2 + 0, 0, 0, 9, 10, 11, // batch 3 + }); + test.Run(); +} +#endif + +} // namespace test +} // namespace onnxruntime diff --git a/orttraining/orttraining/test/training_ops/cpu/tensor/gather_nd_op_test.cc b/orttraining/orttraining/test/training_ops/cpu/tensor/gather_nd_op_test.cc deleted file mode 100644 index 0d01d27299..0000000000 --- a/orttraining/orttraining/test/training_ops/cpu/tensor/gather_nd_op_test.cc +++ /dev/null @@ -1,343 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "gtest/gtest.h" -#include "test/providers/provider_test_utils.h" -#include "test/common/cuda_op_test_utils.h" - -namespace onnxruntime { -namespace test { - -namespace { -// Returns a vector of `count` values which start at `start` and change by increments of `step`. -template -std::vector ValueRange( - size_t count, T start = static_cast(0), T step = static_cast(1)) { - std::vector result; - result.reserve(count); - T curr = start; - for (size_t i = 0; i < count; ++i) { - result.emplace_back(curr); - curr += step; - } - return result; -} -} // namespace - -TEST(GatherNDOpTest, GatherND_scalar_string_int32) { - OpTester test1("GatherND", 1, onnxruntime::kOnnxDomain); - test1.AddInput("data", {2, 2}, {"h", "k", "o", "z"}); - test1.AddInput("indices", {2}, {0, 1}); - test1.AddOutput("output", {}, {"k"}); - test1.Run(); - - OpTester test2("GatherND", 1, onnxruntime::kOnnxDomain); - test2.AddInput("data", {6}, {"h", "k", "o", "z", "l", "t"}); - test2.AddInput("indices", {1}, {3}); - test2.AddOutput("output", {}, {"z"}); - test2.Run(); - - OpTester test3("GatherND", 1, onnxruntime::kOnnxDomain); - test3.AddInput("data", {3, 2}, {"h", "k", "o", "z", "l", "t"}); - test3.AddInput("indices", {2}, {2, 1}); - test3.AddOutput("output", {}, {"t"}); - test3.Run(); -} - -TEST(GatherNDOpTest, GatherND_matrix_int64_int64) { - OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); - test.AddInput("data", {2, 2}, {0LL, 1LL, 2LL, 3LL}); - test.AddInput("indices", {2, 2}, {0LL, 0LL, 1LL, 1LL}); - test.AddOutput("output", {2}, {0LL, 3LL}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_matrix_string_int64) { - OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); - test.AddInput("data", {2, 2}, {"a", "b", "c", "d"}); - test.AddInput("indices", {2, 2}, {0LL, 0LL, 1LL, 1LL}); - test.AddOutput("output", {2}, {"a", "d"}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_matrix_int64_int32) { - OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); - test.AddInput("data", 
{2, 2}, {0LL, 1LL, 2LL, 3LL}); - test.AddInput("indices", {2, 2}, {0, 0, 1, 1}); - test.AddOutput("output", {2}, {0LL, 3LL}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_matrix_string_int32) { - OpTester test1("GatherND", 1, onnxruntime::kOnnxDomain); - test1.AddInput("data", {2, 2, 2}, {"egg", "dance", "air", "bob", "terry", "smart", "laugh", "kite"}); - test1.AddInput("indices", {2, 1, 2}, {0, 1, 1, 0}); - test1.AddOutput("output", {2, 1, 2}, {"air", "bob", "terry", "smart"}); - test1.Run(); - - OpTester test2("GatherND", 1, onnxruntime::kOnnxDomain); - test2.AddInput("data", {3, 3}, {"egg", "dance", "air", "bob", "terry", "smart", "laugh", "kite", "hop"}); - test2.AddInput("indices", {3, 2}, {2, 1, 1, 0, 0, 1}); - test2.AddOutput("output", {3}, {"kite", "bob", "dance"}); - test2.Run(); -} - -TEST(GatherNDOpTest, GatherND_slice_float_int64_t) { - OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); - test.AddInput("data", {2, 2}, {0.0f, 0.1f, 0.2f, 0.3f}); - test.AddInput("indices", {2, 1}, {1LL, 0LL}); - test.AddOutput("output", {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_slice_float_int64_t_axis_0) { - OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); - test.AddAttribute("axis", 0); - test.AddInput("data", {2, 3, 4}, ValueRange(24, 1.0f)); - test.AddInput("indices", {3, 2, 2}, {0LL, 1LL, 0LL, 2LL, 1LL, 0LL, 0LL, 0LL, 1LL, 1LL, 1LL, 2LL}); - test.AddOutput("output", {3, 2, 4}, {5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 1.0, 2.0, 3.0, 4.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_slice_float_int64_t_axis_1) { - OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); - test.AddAttribute("axis", 1); - test.AddInput("data", {2, 3, 4}, ValueRange(24, 1.0f)); - test.AddInput("indices", {2, 2, 2}, {0LL, 1LL, 0LL, 2LL, 1LL, 0LL, 0LL, 0LL}); - test.AddOutput("output", {2, 2}, {2.0, 3.0, 17.0, 13.0}); - test.Run(); -} - 
-TEST(GatherNDOpTest, GatherND_slice_float_int32_t_axis_2) { - OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); - test.AddAttribute("axis", 1); - test.AddInput("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f)); - test.AddInput("indices", {2, 1}, {1LL, 0LL}); - test.AddOutput("output", {2, 2}, {0.2f, 0.3f, 0.4f, 0.5f}); - test.Run(); -} - -#ifdef USE_CUDA -#if __CUDA_ARCH__ >= 600 -TEST(GatherNDOpTest, GatherND_slice_double_int64_t_axis_3) { - OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); - test.AddAttribute("axis", 1); - test.AddInput("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f)); - test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); - test.AddOutput("output", {2, 1, 2}, {0.2f, 0.3f, 0.4f, 0.5f}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_slice_double_int32_t) { - OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); - test.AddInput("data", {2, 2}, {0.0f, 0.1f, 0.2f, 0.3f}); - test.AddInput("indices", {2, 1}, {1LL, 0LL}); - test.AddOutput("output", {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f}); - test.Run(); -} -#endif -#endif - -TEST(GatherNDOpTest, GatherND_slice_float_int64_t_axis_4) { - OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); - test.AddAttribute("axis", 1); - test.AddInput("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f)); - test.AddInput("indices", {2, 1, 2}, {1LL, 0LL, 0LL, 1LL}); - test.AddOutput("output", {2, 1}, {0.2f, 0.5f}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_3tensor_int64) { - OpTester test1("GatherND", 1, onnxruntime::kOnnxDomain); - test1.AddInput("data", {2, 2, 2}, ValueRange(8)); - test1.AddInput("indices", {2, 2}, {0LL, 1LL, 1LL, 0LL}); - test1.AddOutput("output", {2, 2}, {2LL, 3LL, 4LL, 5LL}); - test1.Run(); - - OpTester test2("GatherND", 1, onnxruntime::kOnnxDomain); - test2.AddInput("data", {2, 2, 2}, ValueRange(8)); - test2.AddInput("indices", {2, 3}, {0, 0, 1, 1, 0, 1}); - test2.AddOutput("output", {2}, {1, 5}); - test2.Run(); - - OpTester test3("GatherND", 1, onnxruntime::kOnnxDomain); - 
test3.AddInput("data", {2, 2, 2}, ValueRange(8)); - test3.AddInput("indices", {1, 1}, {1LL}); - test3.AddOutput("output", {1, 2, 2}, {4, 5, 6, 7}); - test3.Run(); -} - -TEST(GatherNDOpTest, GatherND_batched_index_int64) { - OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); - test.AddInput("data", {2, 2}, {0LL, 1LL, 2LL, 3LL}); - test.AddInput("indices", {2, 1, 2}, {0LL, 0LL, 0LL, 1LL}); - test.AddOutput("output", {2, 1}, {0LL, 1LL}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_batched_index_bool_int64) { - OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); - test.AddInput("data", {2, 2}, {true, false, false, true}); - test.AddInput("indices", {2, 1, 2}, {0LL, 0LL, 0LL, 1LL}); - test.AddOutput("output", {2, 1}, {true, false}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_sliced_index_int64) { - OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); - test.AddInput("data", {2, 2}, {0LL, 1LL, 2LL, 3LL}); - test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); - test.AddOutput("output", {2, 1, 2}, {2LL, 3LL, 0LL, 1LL}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_sliced_index_string_int32) { - OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); - test.AddInput("data", {2, 2}, {"ab", "cde", "f", "ghi"}); - test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); - test.AddOutput("output", {2, 1, 2}, {"f", "ghi", "ab", "cde"}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_batched_3tensor_int64) { - OpTester test1("GatherND", 1, onnxruntime::kOnnxDomain); - test1.AddInput("data", {2, 2, 2}, ValueRange(8)); - test1.AddInput("indices", {2, 2, 2}, {0LL, 1LL, 1LL, 0LL, 0LL, 0LL, 1LL, 1LL}); - test1.AddOutput("output", {2, 2, 2}, {2, 3, 4, 5, 0, 1, 6, 7}); - test1.Run(); - - OpTester test2("GatherND", 1, onnxruntime::kOnnxDomain); - test2.AddInput("data", {2, 2, 2}, ValueRange(8)); - test2.AddInput("indices", {2, 2, 3}, {0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0}); - test2.AddOutput("output", {2, 2}, {1, 5, 3, 6}); - test2.Run(); - - OpTester 
test3("GatherND", 1, onnxruntime::kOnnxDomain); - test3.AddInput("data", {2, 2, 2}, ValueRange(8)); - test3.AddInput("indices", {2, 1, 1}, {1, 0}); - test3.AddOutput("output", {2, 1, 2, 2}, {4LL, 5LL, 6LL, 7LL, 0LL, 1LL, 2LL, 3LL}); - test3.Run(); -} - -#ifdef USE_CUDA -TEST(GatherNDOpTest, GatherNDGrad_slice_float_int64_t_axis_1) { - OpTester test("GatherNDGrad", 1, onnxruntime::kOnnxDomain); - test.AddAttribute("axis", 0); - test.AddInput("shape", {3}, {2LL, 2LL, 3LL}); - test.AddInput("indices", {2, 2}, {0LL, 1LL, 1LL, 0LL}); - test.AddInput("update", {2, 3}, ValueRange(6, 1.0f)); - test.AddOutput("output", {2, 2, 3}, {0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0}); - test.Run(); -} -#endif - -#ifdef USE_CUDA -TEST(GatherNDOpTest, GatherNDGrad_slice_double_int32_t_axis_3) { - if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; - - OpTester test("GatherNDGrad", 1, onnxruntime::kOnnxDomain); - test.AddAttribute("axis", 1); - test.AddInput("shape", {3}, {2LL, 2LL, 3LL}); - test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); - test.AddInput("update", {2, 3}, ValueRange(6, 1.0)); - test.AddOutput("output", {2, 2, 3}, {0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_slice_double_int64_t_axis_3) { - if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; - - OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); - test.AddAttribute("axis", 1); - test.AddInput("data", {2, 2, 2}, ValueRange(8, 0.0, 0.1)); - test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); - test.AddOutput("output", {2, 1, 2}, {0.2f, 0.3f, 0.4f, 0.5f}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherNDGrad_slice_half_int32_t_axis_3) { - if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; - - OpTester test("GatherNDGrad", 1, onnxruntime::kOnnxDomain); - test.AddAttribute("axis", 1); - test.AddInput("shape", {3}, {2LL, 2LL, 3LL}); - test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); - std::vector updates_f = ValueRange(6, 1.0f); - std::vector 
outputs_f({0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0}); - std::vector updates(6); - std::vector outputs(12); - ConvertFloatToMLFloat16(updates_f.data(), updates.data(), 6); - ConvertFloatToMLFloat16(outputs_f.data(), outputs.data(), 12); - test.AddInput("update", {2, 3}, updates); - test.AddOutput("output", {2, 2, 3}, outputs); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_slice_half_int32_t) { - if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; - - OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); - std::vector data_f({0.0f, 0.1f, 0.2f, 0.3f}); - std::vector outputs_f({0.2f, 0.3f, 0.0f, 0.1f}); - std::vector data(4); - std::vector outputs(4); - ConvertFloatToMLFloat16(data_f.data(), data.data(), 4); - ConvertFloatToMLFloat16(outputs_f.data(), outputs.data(), 4); - test.AddInput("data", {2, 2}, data); - test.AddInput("indices", {2, 1}, {1LL, 0LL}); - test.AddOutput("output", {2, 2}, outputs); - test.Run(); -} -#endif - -#ifdef USE_CUDA -TEST(GatherNDOpTest, GatherND_axis_of_2) { - OpTester test("GatherND", 1, kOnnxDomain); - test.AddAttribute("axis", 2); - test.AddInput("data", {2, 2, 2, 2, 3}, ValueRange(48)); - test.AddInput( - "indices", {2, 2, 1, 2}, - { - 0, 0, // batch 0 - 1, 0, // batch 1 - 1, 1, // batch 2 - 0, 1, // batch 3 - }); - test.AddOutput( - "output", {2, 2, 1, 3}, - { - 0, 1, 2, // batch 0 - 18, 19, 20, // batch 1 - 33, 34, 35, // batch 2 - 39, 40, 41, // batch 3 - }); - test.Run(); -} - -TEST(GatherNDOpTest, GatherNDGrad_axis_of_2) { - OpTester test("GatherNDGrad", 1, kOnnxDomain); - test.AddAttribute("axis", 2); - test.AddInput("shape", {4}, {2, 2, 2, 3}); - test.AddInput( - "indices", {2, 2, 1}, - { - 1, // batch 0 - 1, // batch 1 - 0, // batch 2 - 1, // batch 3 - }); - test.AddInput("update", {2, 2, 3}, ValueRange(12)); - test.AddOutput( - "output", {2, 2, 2, 3}, - { - 0, 0, 0, 0, 1, 2, // batch 0 - 0, 0, 0, 3, 4, 5, // batch 1 - 6, 7, 8, 0, 0, 0, // batch 2 - 0, 0, 0, 9, 10, 11, // batch 3 - }); - test.Run(); -} -#endif - -} 
// namespace test -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/common.h b/orttraining/orttraining/training_ops/cpu/controlflow/common.h new file mode 100644 index 0000000000..c6ff14689f --- /dev/null +++ b/orttraining/orttraining/training_ops/cpu/controlflow/common.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +namespace contrib { + +template +std::vector> AliasRange(int start, int end) { + std::vector> aliases; + for (int i = start; i < end; i++) { + aliases.push_back(std::pair(input_start + i, output_start + i)); + } + return aliases; +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.cc b/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.cc index a39ab399c2..e9bb0c51c3 100644 --- a/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.cc +++ b/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.cc @@ -6,8 +6,16 @@ namespace onnxruntime { namespace contrib { +void OrtEventPool::CheckRange(const int64_t id) const { + ORT_ENFORCE( + id >= 0 && id < MaxNumItems, + "Got id ", id, + ". 
It should be in a range from 0 to ", + MaxNumItems, "."); +} + void OrtEventPool::SignalEvent(int64_t id) { - ORT_ENFORCE(id >= 0 && id < MaxNumItems); + CheckRange(id); std::unique_lock lock(pool_[id].mutex); pool_[id].signaled.store(true); lock.unlock(); @@ -15,18 +23,18 @@ void OrtEventPool::SignalEvent(int64_t id) { }; void OrtEventPool::ResetEvent(int64_t id) { - ORT_ENFORCE(id >= 0 && id < MaxNumItems); + CheckRange(id); std::lock_guard guard(pool_[id].mutex); pool_[id].signaled.store(false); }; bool OrtEventPool::QueryEvent(int64_t id) const { - ORT_ENFORCE(id >= 0 && id < MaxNumItems); + CheckRange(id); return pool_[id].signaled.load(); } void OrtEventPool::WaitEvent(int64_t id) const { - ORT_ENFORCE(id >= 0 && id < MaxNumItems); + CheckRange(id); std::unique_lock lock(pool_[id].mutex); pool_[id].cv.wait(lock, [this, id] { return pool_[id].signaled.load(); }); }; diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.h b/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.h index 511a5da234..68e9b95abb 100644 --- a/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.h +++ b/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.h @@ -34,6 +34,8 @@ class OrtEventPool final { OrtEventPool(const OrtEventPool&) = delete; OrtEventPool& operator=(const OrtEventPool&) = delete; + void CheckRange(const int64_t event_id) const; + struct Item { std::atomic signaled; mutable std::mutex mutex; @@ -43,9 +45,11 @@ class OrtEventPool final { signaled.store(false); } }; + enum { MaxNumItems = 4096 }; + Item pool_[MaxNumItems]; }; diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/record.cc b/orttraining/orttraining/training_ops/cpu/controlflow/record.cc index c36a23e9e7..eb305a7c27 100644 --- a/orttraining/orttraining/training_ops/cpu/controlflow/record.cc +++ b/orttraining/orttraining/training_ops/cpu/controlflow/record.cc @@ -1,19 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Licensed under the MIT License. -#include "record.h" +#include "orttraining/training_ops/cpu/controlflow/record.h" #include "core/providers/cpu/tensor/utils.h" +#include "common.h" namespace onnxruntime { namespace contrib { -template -std::vector> AliasRange(int start, int end) { - std::vector> aliases; - for (int i = start; i < end; i++) { - aliases.push_back(std::pair(input_start + i, output_start + i)); - } - return aliases; +void record_event_in_tensor(const Tensor& event_id_tensor) { + const int64_t event_id = *event_id_tensor.template Data(); + ORT_ENFORCE(event_id != -1, "-1 is reserved for skip wait, so cannot be used in RecordEvent"); + OrtEventPool::GetInstance().SignalEvent(event_id); } ONNX_OPERATOR_KERNEL_EX( @@ -28,12 +26,8 @@ ONNX_OPERATOR_KERNEL_EX( RecordEvent); Status RecordEvent::Compute(OpKernelContext* ctx) const { - const Tensor* event_id_tensor = ctx->Input(0); - const int64_t event_id = *event_id_tensor->template Data(); - - ORT_RETURN_IF_NOT(event_id != -1, "-1 is reserved for skip wait, so cannot be used in RecordEvent"); - - OrtEventPool::GetInstance().SignalEvent(event_id); + // Pass event-id tensor to event-recording helper function. + record_event_in_tensor(*ctx->Input(0)); for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) { const Tensor* X = ctx->Input(i_out + 1); diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/record.h b/orttraining/orttraining/training_ops/cpu/controlflow/record.h index d4f02a612d..61fb28abba 100644 --- a/orttraining/orttraining/training_ops/cpu/controlflow/record.h +++ b/orttraining/orttraining/training_ops/cpu/controlflow/record.h @@ -9,6 +9,9 @@ namespace onnxruntime { namespace contrib { +// Record the event ID stored in the input tensor. 
+void record_event_in_tensor(const Tensor& event_id_tensor); + class RecordEvent final : public OpKernel { public: RecordEvent(const OpKernelInfo& info) : OpKernel(info) { } diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/wait.cc b/orttraining/orttraining/training_ops/cpu/controlflow/wait.cc index b08ba65ddd..67f3128234 100644 --- a/orttraining/orttraining/training_ops/cpu/controlflow/wait.cc +++ b/orttraining/orttraining/training_ops/cpu/controlflow/wait.cc @@ -1,19 +1,23 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "wait.h" +#include "orttraining/training_ops/cpu/controlflow/wait.h" #include "core/providers/cpu/tensor/utils.h" +#include "common.h" namespace onnxruntime { namespace contrib { -template -std::vector> AliasRange(int start, int end) { - std::vector> aliases; - for (int i = start; i < end; i++) { - aliases.push_back(std::pair(input_start + i, output_start + i)); +void wait_event_in_tensor(const Tensor& event_id_tensor) { + const int64_t event_id = *event_id_tensor.template Data(); + // -1 is reserved to skip wait event + if (event_id != -1) { + // Wait the event to be recorded by a RecordEvent operator. + OrtEventPool::GetInstance().WaitEvent(event_id); + // BUGBUG: seems this would cause hang when a event is being waited more than once + // Destory the recorded event. + OrtEventPool::GetInstance().ResetEvent(event_id); } - return aliases; } ONNX_OPERATOR_KERNEL_EX( @@ -28,18 +32,7 @@ ONNX_OPERATOR_KERNEL_EX( WaitEvent); Status WaitEvent::Compute(OpKernelContext* ctx) const { - const Tensor* event_id_tensor = ctx->Input(0); - const int64_t event_id = *event_id_tensor->template Data(); - - // -1 is reserved to skip wait event - if (event_id != -1) { - // Wait the event to be recorded by a RecordEvent operator. 
- OrtEventPool::GetInstance().WaitEvent(event_id); - - // BUGBUG: seems this would cause hang when a event is being waited more than once - // Destory the recorded event. - OrtEventPool::GetInstance().ResetEvent(event_id); - } + wait_event_in_tensor(*ctx->Input(0)); for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) { const Tensor* X = ctx->Input(i_out + 1); diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/wait.h b/orttraining/orttraining/training_ops/cpu/controlflow/wait.h index dff514880f..682aa4388e 100644 --- a/orttraining/orttraining/training_ops/cpu/controlflow/wait.h +++ b/orttraining/orttraining/training_ops/cpu/controlflow/wait.h @@ -2,8 +2,6 @@ // Licensed under the MIT License. #pragma once -#include -#include #include "core/common/common.h" #include "core/framework/op_kernel.h" #include "event_pool.h" @@ -11,6 +9,9 @@ namespace onnxruntime { namespace contrib { +// Wait for the event ID stored in the input tensor. +void wait_event_in_tensor(const Tensor& event_id_tensor); + class WaitEvent final : public OpKernel { public: WaitEvent(const OpKernelInfo& info) : OpKernel(info) { } diff --git a/orttraining/orttraining/training_ops/cpu_training_kernels.cc b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc similarity index 98% rename from orttraining/orttraining/training_ops/cpu_training_kernels.cc rename to orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc index edde3947dc..c8783baff9 100644 --- a/orttraining/orttraining/training_ops/cpu_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "orttraining/training_ops/cpu_training_kernels.h" +#include "orttraining/training_ops/cpu/cpu_training_kernels.h" #include "core/graph/constants.h" namespace onnxruntime { @@ -17,7 +17,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InPla class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ZeroGradient); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Group); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxCrossEntropy); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxCrossEntropyGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SparseSoftmaxCrossEntropy); @@ -106,7 +105,6 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cpu_training_kernels.h b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.h similarity index 100% rename from orttraining/orttraining/training_ops/cpu_training_kernels.h rename to orttraining/orttraining/training_ops/cpu/cpu_training_kernels.h diff --git a/orttraining/orttraining/training_ops/cpu/tensor/gather_nd.cc b/orttraining/orttraining/training_ops/cpu/tensor/gather_nd.cc deleted file mode 100644 index 6e11910ea2..0000000000 --- a/orttraining/orttraining/training_ops/cpu/tensor/gather_nd.cc +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "orttraining/training_ops/cpu/tensor/gather_nd.h" - -namespace onnxruntime { -namespace contrib { - -ONNX_OPERATOR_KERNEL_EX( - GatherND, - kOnnxDomain, - 1, - kCpuExecutionProvider, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) - .TypeConstraint("Tind", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), - GatherND); - -template -Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const { - auto input_tensor = context->Input(0); - auto indice_tensor = context->Input(1); - ORT_ENFORCE(input_tensor != nullptr); - ORT_ENFORCE(indice_tensor != nullptr); - - auto input_shape = input_tensor->Shape(); - auto indice_shape = indice_tensor->Shape(); - if (indice_shape.NumDimensions() == 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "indices tensor must has rank larger than 0"); - } - - auto last_indice_dimension = indice_shape[indice_shape.NumDimensions() - 1] + axis_; - if (last_indice_dimension > static_cast(input_shape.NumDimensions())) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "last dimension of indices must not be larger than rank of input tensor"); - } - - std::vector shape(indice_shape.GetDims().begin(), - indice_shape.GetDims().end() - 1); - shape.insert(shape.end(), - input_shape.GetDims().begin() + last_indice_dimension, - input_shape.GetDims().end()); - auto output_tensor = context->Output(0, TensorShape(shape)); - std::vector element_counts(last_indice_dimension + axis_, 0LL); // Number of elements for each input dimension - -#ifdef USE_OPENMP -#pragma omp parallel for -#endif - for (int64_t i = 0; i < last_indice_dimension; ++i) { - element_counts[i] = input_shape.SizeFromDimension(i + 1); - } - - auto last_dim_size = indice_shape.SizeFromDimension(indice_shape.NumDimensions() - 1); -#ifdef USE_OPENMP -#pragma omp parallel for -#endif - for (int64_t i = axis_ - 1; i >= 0; --i) { - element_counts[last_indice_dimension + i] = 
indice_shape.SizeFromDimension(i + 1) / last_dim_size; - } - - int64_t err_indice = 0; - p.element_bytes = input_tensor->DataType()->Size(); - p.element_to_copy = input_shape.SizeFromDimension(last_indice_dimension); - p.bytes_to_copy = p.element_bytes * p.element_to_copy; - auto indice_offset = indice_tensor->Data(); - auto offset_count = indice_shape.Size() / (last_indice_dimension - axis_); // Times to copy - p.element_offsets.assign(offset_count, 0LL); - - if (input_tensor->DataType() == DataTypeImpl::GetType()) { - p.input_str_base = static_cast(input_tensor->DataRaw()); - p.output_str_base = static_cast(output_tensor->MutableDataRaw()); - } else { - p.input_base = static_cast(input_tensor->DataRaw()); - p.output_base = static_cast(output_tensor->MutableDataRaw()); - } - - //Compute the element_offset -#ifdef USE_OPENMP -#pragma omp parallel for -#endif - for (int64_t i = 0; i < offset_count; ++i) { - int64_t reminder = i; - for (int64_t j = 0; j < axis_; ++j) { - int64_t idx = reminder / element_counts[last_indice_dimension + j]; - p.element_offsets[i] += idx * element_counts[j]; - reminder -= (idx * element_counts[last_indice_dimension + j]); - } - for (int64_t j = axis_; j < last_indice_dimension; ++j) { - auto indice = *(indice_offset + i * (last_indice_dimension - axis_) + (j - axis_)); - if (indice < 0 || indice >= input_shape[j]) { - err_indice = indice; - } - p.element_offsets[i] += indice * element_counts[j]; - } - } - - return err_indice == 0 ? Status::OK() : ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", err_indice); -} - -template Status GatherNDBase::PrepareForCompute(OpKernelContext*, Prepare&) const; -template Status GatherNDBase::PrepareForCompute(OpKernelContext*, Prepare&) const; - -Status GatherND::Compute(OpKernelContext* context) const { - Prepare p; - ORT_RETURN_IF_ERROR(context->Input(1)->DataType() == DataTypeImpl::GetType() ? 
PrepareForCompute(context, p) : PrepareForCompute(context, p)); - - return nullptr == p.input_str_base ? GatherNumber(p) : GatherString(p); -} - -Status GatherND::GatherNumber(const Prepare& p) const { -#ifdef USE_OPENMP -#pragma omp parallel for -#endif - for (int64_t i = 0; i < static_cast(p.element_offsets.size()); ++i) { - memcpy(p.output_base + i * p.bytes_to_copy, - p.input_base + p.element_offsets[i] * p.element_bytes, - p.bytes_to_copy); - } - - return Status::OK(); -} - -Status GatherND::GatherString(const Prepare& p) const { -#ifdef USE_OPENMP -#pragma omp parallel for -#endif - for (int64_t i = 0; i < static_cast(p.element_offsets.size()); ++i) { - for (int64_t j = 0; j < static_cast(p.element_to_copy); ++j) { - p.output_str_base[i * p.element_to_copy + j] = p.input_str_base[p.element_offsets[i] + j]; - } - } - - return Status::OK(); -} - -} // namespace contrib -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/tensor/gather_nd.h b/orttraining/orttraining/training_ops/cpu/tensor/gather_nd.h deleted file mode 100644 index 2b1caddad9..0000000000 --- a/orttraining/orttraining/training_ops/cpu/tensor/gather_nd.h +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include "core/common/common.h" -#include "core/framework/op_kernel.h" -#include "core/platform/threadpool.h" - -namespace onnxruntime { -namespace contrib { - -class GatherNDBase { - protected: - struct Prepare { - const uint8_t* input_base; - const std::string* input_str_base; - uint8_t* output_base; - std::string* output_str_base; - uint64_t bytes_to_copy; - uint64_t element_bytes; - uint64_t element_to_copy; - std::vector element_offsets; - - Prepare() : input_base(nullptr), - input_str_base(nullptr), - output_base(nullptr), - output_str_base(nullptr), - bytes_to_copy(0), - element_bytes(0), - element_to_copy(0), - element_offsets(0) {} - }; // struct Prepare - - template - Status PrepareForCompute(OpKernelContext* context, Prepare& p) const; - int64_t axis_; -}; // class GatherNDBase - -class GatherND final : public OpKernel, protected GatherNDBase { - public: - explicit GatherND(const OpKernelInfo& info) : OpKernel(info) { - info.GetAttrOrDefault("axis", &axis_, static_cast(0)); - } - Status Compute(OpKernelContext* context) const override; - - private: - Status GatherNumber(const Prepare& p) const; - Status GatherString(const Prepare& p) const; -}; - -} // namespace contrib -} // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/training_ops/cuda/controlflow/record.cc b/orttraining/orttraining/training_ops/cuda/controlflow/record.cc new file mode 100644 index 0000000000..2f8bb82a2e --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/controlflow/record.cc @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/training_ops/cuda/controlflow/record.h" +#include "core/providers/cpu/tensor/utils.h" +// Include RecordEvent's utility functions shared by CPU and GPU implementations. +#include "orttraining/training_ops/cpu/controlflow/common.h" +// Include event mechanism shared by CPU and GPU implementations. 
+#include "orttraining/training_ops/cpu/controlflow/event_pool.h" +#include "orttraining/training_ops/cpu/controlflow/record.h" + +namespace onnxruntime { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + RecordEvent, + kMSDomain, + 1, + kCudaExecutionProvider, + KernelDefBuilder() + .InputMemoryType(0) /* Keep EventIdentifier in CPU */ + .TypeConstraint("TInt64", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .Alias(onnxruntime::contrib::AliasRange<1, 0>(0, 1024)), + RecordEvent); + +Status RecordEvent::ComputeInternal(OpKernelContext* ctx) const { + // Reuse CPU helper to record event because event tensor is a CPU tensor. + onnxruntime::contrib::record_event_in_tensor(*ctx->Input(0)); + + for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) { + // This iteration copies (i-1)-th input to i-th output. + const Tensor* X = ctx->Input(i_out + 1); + const TensorShape& data_shape = X->Shape(); + Tensor* Y = ctx->Output(i_out, data_shape); + CopyTensor(*X, *Y); + } + + return Status::OK(); +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/controlflow/record.h b/orttraining/orttraining/training_ops/cuda/controlflow/record.h new file mode 100644 index 0000000000..0063af48f1 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/controlflow/record.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cudnn_common.h" + +namespace onnxruntime { +namespace cuda { + +class RecordEvent final : public CudaKernel { +public: + RecordEvent(const OpKernelInfo& info) : CudaKernel(info) { } + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/training_ops/cuda/controlflow/wait.cc b/orttraining/orttraining/training_ops/cuda/controlflow/wait.cc new file mode 100644 index 0000000000..f90177c51a --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/controlflow/wait.cc @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/training_ops/cuda/controlflow/wait.h" +#include "core/providers/cpu/tensor/utils.h" +// Include RecordEvent's utility functions shared by CPU and GPU implementations. +#include "orttraining/training_ops/cpu/controlflow/common.h" +// Include event mechanism shared by CPU and GPU implementations. +#include "orttraining/training_ops/cpu/controlflow/event_pool.h" +#include "orttraining/training_ops/cpu/controlflow/wait.h" + +namespace onnxruntime { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + WaitEvent, + kMSDomain, + 1, + kCudaExecutionProvider, + KernelDefBuilder() + .InputMemoryType(0) /* CPU variable */ + .TypeConstraint("TInt64", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .Alias(onnxruntime::contrib::AliasRange<1, 0>(0, 1024)), + WaitEvent); + +Status WaitEvent::ComputeInternal(OpKernelContext* ctx) const { + // Reuse CPU helper to wait event because event tensor is a CPU tensor. 
+ onnxruntime::contrib::wait_event_in_tensor(*ctx->Input(0)); + + for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) { + // This iteration copies (i-1)-th input to i-th output. + const Tensor* X = ctx->Input(i_out + 1); + const TensorShape& data_shape = X->Shape(); + Tensor* Y = ctx->Output(i_out, data_shape); + CopyTensor(*X, *Y); + } + + return Status::OK(); +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/controlflow/wait.h b/orttraining/orttraining/training_ops/cuda/controlflow/wait.h new file mode 100644 index 0000000000..b4a687fef6 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/controlflow/wait.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cudnn_common.h" + +namespace onnxruntime { +namespace cuda { + +class WaitEvent final : public CudaKernel { +public: + WaitEvent(const OpKernelInfo& info) : CudaKernel(info) { } + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/training_ops/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc similarity index 97% rename from orttraining/orttraining/training_ops/cuda_training_kernels.cc rename to orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index 76c04462ae..537864f1a1 100644 --- a/orttraining/orttraining/training_ops/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -88,10 +88,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_MLFloat16, DropoutGrad); class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_float, DropoutGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double, DropoutGrad); + +// TODO: decprecate GatherND-1 after updating training models to opset-12 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, GatherND); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, GatherND); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, GatherNDGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, GatherNDGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, GatherNDGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DivGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, DivGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DivGrad); @@ -131,6 +131,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Send class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Recv); #endif +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, RecordEvent); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WaitEvent); + #ifdef USE_NCCL class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NcclAllReduce); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NcclAllGather); @@ -209,10 +212,10 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + // TODO: decprecate GatherND-1 after updating training models to opset-12 BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, @@ -263,6 +266,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, #endif + BuildKernelCreateInfo, + BuildKernelCreateInfo, + #ifdef USE_NCCL BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cuda_training_kernels.h b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.h similarity index 100% rename from orttraining/orttraining/training_ops/cuda_training_kernels.h rename to orttraining/orttraining/training_ops/cuda/cuda_training_kernels.h diff --git a/orttraining/orttraining/training_ops/cuda/nn/dropout.cc b/orttraining/orttraining/training_ops/cuda/nn/dropout.cc index bced973f88..44e39f8af3 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/dropout.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/dropout.cc @@ -37,7 +37,6 @@ REGISTER_KERNEL_TYPED(Dropout, kOnnxDomain, 12, double, double, 1) template Status Dropout::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; - typedef typename ToCudaType::MappedType CudaT2; //Get X_data const Tensor* X = context->Input(0); @@ -68,7 +67,7 @@ Status Dropout::ComputeInternal(OpKernelContext* context) const { "T2 must be float16 or float or double"); if (ratio) { - ratio_data = static_cast(*reinterpret_cast(ratio->template Data())); + ratio_data = static_cast(*(ratio->template Data())); } else { ratio_data = default_ratio_; } @@ -112,7 +111,7 @@ Status DropoutGrad::ComputeInternal(OpKernelContext* context) const { "T2 must be float16 or float or double"); if (ratio) { - ratio_data = static_cast(*reinterpret_cast(ratio->template Data())); + ratio_data = static_cast(*(ratio->template Data())); } else { ratio_data = default_ratio_; } diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_grad.cc 
b/orttraining/orttraining/training_ops/cuda/tensor/gather_grad.cc index 70c9c58247..56f0e4e984 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/gather_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/tensor/gather_grad.cc @@ -3,7 +3,6 @@ #include "orttraining/training_ops/cuda/tensor/gather_grad.h" #include "orttraining/training_ops/cuda/tensor/gather_grad_impl.h" -#include "orttraining/training_ops/cuda/tensor/thrustallocator.h" #include "core/providers/common.h" namespace onnxruntime { diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_gard_impl.cu b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_gard_impl.cu new file mode 100644 index 0000000000..8c37836705 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_gard_impl.cu @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/tensor/gather_nd_impl.h" + +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/atomic/common.cuh" + +namespace onnxruntime { +namespace cuda { + +template +__global__ void _GatherNDGradKernel( + const size_t num_slices, + const T* update_data, + T* output_data, + const size_t slice_size, + const int64_t* slice_offsets) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(i, num_slices * slice_size); + uint64_t slice_offset = slice_offsets[i / slice_size]; + size_t j = i % slice_size; + atomic_add(output_data + slice_offset + j, update_data[i]); +}; + +template +void GatherNDGradImpl( + const size_t num_slices, + const void* update_data, + void* output_data, + const size_t slice_size, + const int64_t* input_slice_offsets_data) { + const auto blocks_per_grid = CeilDiv(num_slices * slice_size, GridDim::maxThreadsPerBlock); + _GatherNDGradKernel<<>>( + num_slices, static_cast(update_data), static_cast(output_data), slice_size, input_slice_offsets_data); +} + +#define SPECIALIZED_GRAD_IMPL(T) \ + template void 
GatherNDGradImpl(const size_t num_slices, const void* update_data, void* output_data, const size_t slice_size, const int64_t* input_slice_offsets_data) + +SPECIALIZED_GRAD_IMPL(float); +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 +SPECIALIZED_GRAD_IMPL(half); +SPECIALIZED_GRAD_IMPL(double); +#endif + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.cc b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.cc new file mode 100644 index 0000000000..d7878cf0a6 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.cc @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/training_ops/cuda/tensor/gather_nd_grad.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace cuda { + +#define REGISTER_KERNEL_TYPED_GATHER_ND_GRAD(TIndex) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GatherNDGrad, \ + kMSDomain, \ + 1, \ + TIndex, \ + kCudaExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) \ + .TypeConstraint("Tind", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(0), \ + GatherNDGrad); + +REGISTER_KERNEL_TYPED_GATHER_ND_GRAD(int64_t) + +template +Status GatherNDGrad::ComputeInternal(OpKernelContext* context) const { + auto shape_tensor = context->Input(0); + auto indices_tensor = context->Input(1); + auto update_tensor = context->Input(2); + ORT_RETURN_IF_NOT(shape_tensor != nullptr); + ORT_RETURN_IF_NOT(indices_tensor != nullptr); + ORT_RETURN_IF_NOT(update_tensor != nullptr); + + auto indices_shape = indices_tensor->Shape(); + auto update_shape = update_tensor->Shape(); + + if (indices_shape.NumDimensions() == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, 
+ "indices tensor must has rank larger than 0"); + } + + auto last_indices_dimension = batch_dims_ + indices_shape[indices_shape.NumDimensions() - 1]; + + //Output + auto shape_data = shape_tensor->Data(); + auto input_shape = TensorShape(shape_data, shape_tensor->SizeInBytes() / sizeof(shape_tensor->DataType())); + + if (last_indices_dimension > static_cast(input_shape.NumDimensions())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "last dimension of indices must not be larger than rank of input tensor"); + } + + ORT_RETURN_IF_ERROR(CheckBatchDimensionsMatch( + static_cast(batch_dims_), {input_shape, indices_shape, update_shape})); + + auto output_tensor = context->Output(0, input_shape); + + // TODO this memset can be expensive, a sparse tensor representation would help here + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(output_tensor->MutableDataRaw(), 0, output_tensor->SizeInBytes())); + + auto status = CommonComputeKernel(batch_dims_, input_shape, update_tensor, output_tensor, indices_shape, indices_tensor, false); + return status; +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.h b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.h new file mode 100644 index 0000000000..d29d85befa --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.h @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/tensor/gather_nd.h" + +namespace onnxruntime { +namespace cuda { + +template +class GatherNDGrad final : public GatherNDBase { + public: + GatherNDGrad(const OpKernelInfo& info) : GatherNDBase(info) {} + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/training_ops/cuda/tensor/thrustallocator.h b/orttraining/orttraining/training_ops/cuda/tensor/thrustallocator.h deleted file mode 100644 index 182b14b1a4..0000000000 --- a/orttraining/orttraining/training_ops/cuda/tensor/thrustallocator.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/framework/allocator.h" - -namespace onnxruntime { -namespace cuda { - -class ThrustAllocator { - public: - typedef char value_type; - - ThrustAllocator(IAllocator* alloc) : alloc_(alloc) {} - - char* allocate(std::ptrdiff_t size) { - return static_cast(alloc_->Alloc(size)); - } - - void deallocate(char* p, size_t /*size*/) { - alloc_->Free(p); - } - - private: - IAllocator* alloc_; -}; - -} // namespace cuda -} // namespace onnxruntime From 75c24a5facc535e1b28443523d3131089a760450 Mon Sep 17 00:00:00 2001 From: ytaous <4484531+ytaous@users.noreply.github.com> Date: Mon, 27 Apr 2020 20:42:43 -0700 Subject: [PATCH 7/8] Revert "Merge from ort_training to master (#3719)" (#3726) This reverts commit b990ba0059291addeb220e87a88ef9d2d7621baa. 
--- cmake/onnxruntime_providers.cmake | 10 + .../{cpu => }/cpu_contrib_kernels.cc | 2 +- .../{cpu => }/cpu_contrib_kernels.h | 0 .../{cuda => }/cuda_contrib_kernels.cc | 0 .../{cuda => }/cuda_contrib_kernels.h | 0 .../core/graph/contrib_ops/contrib_defs.cc | 71 ++++ .../providers/acl/acl_execution_provider.cc | 2 +- .../providers/cpu/cpu_execution_provider.cc | 10 +- .../core/providers/cpu/tensor/gather_nd.cc | 49 +-- .../core/providers/cpu/tensor/gather_nd.h | 5 +- .../providers/cuda/cuda_execution_provider.cc | 8 +- .../core/providers/cuda/gpu_data_transfer.cc | 5 +- .../{cpu => }/cpu_featurizers_kernels.cc | 2 +- .../{cpu => }/cpu_featurizers_kernels.h | 0 .../test/common/tensor_op_test_utils.h | 14 - .../providers/cpu/tensor/gather_nd_op_test.cc | 138 +------ .../core/graph/gradient_builder.cc | 8 +- .../core/graph/gradient_schema_defs.cc | 80 +--- .../core/graph/loss_func/bert_loss.cc | 4 +- .../python/orttraining_pybind_state.cc | 5 +- .../test/gradient/event_op_test.cc | 122 ------- .../test/gradient/gradient_checker.cc | 4 +- .../test/gradient/gradient_ops_test.cc | 119 ++++-- .../cpu/tensor/gather_nd_grad_op_test.cc | 80 ---- .../cpu/tensor/gather_nd_op_test.cc | 343 ++++++++++++++++++ .../training_ops/cpu/controlflow/common.h | 19 - .../cpu/controlflow/event_pool.cc | 16 +- .../training_ops/cpu/controlflow/event_pool.h | 4 - .../training_ops/cpu/controlflow/record.cc | 22 +- .../training_ops/cpu/controlflow/record.h | 3 - .../training_ops/cpu/controlflow/wait.cc | 31 +- .../training_ops/cpu/controlflow/wait.h | 5 +- .../training_ops/cpu/tensor/gather_nd.cc | 138 +++++++ .../training_ops/cpu/tensor/gather_nd.h | 53 +++ .../{cpu => }/cpu_training_kernels.cc | 4 +- .../{cpu => }/cpu_training_kernels.h | 0 .../training_ops/cuda/controlflow/record.cc | 43 --- .../training_ops/cuda/controlflow/record.h | 19 - .../training_ops/cuda/controlflow/wait.cc | 43 --- .../training_ops/cuda/controlflow/wait.h | 19 - .../training_ops/cuda/nn/dropout.cc | 5 +- 
.../training_ops/cuda/tensor/gather_grad.cc | 1 + .../training_ops}/cuda/tensor/gather_nd.cc | 100 +++-- .../training_ops}/cuda/tensor/gather_nd.h | 19 +- .../cuda/tensor/gather_nd_gard_impl.cu | 47 --- .../cuda/tensor/gather_nd_grad.cc | 67 ---- .../training_ops/cuda/tensor/gather_nd_grad.h | 22 -- .../cuda/tensor/gather_nd_impl.cu | 63 ++-- .../cuda/tensor/gather_nd_impl.h | 4 - .../cuda/tensor/thrustallocator.h | 30 ++ .../{cuda => }/cuda_training_kernels.cc | 18 +- .../{cuda => }/cuda_training_kernels.h | 0 52 files changed, 950 insertions(+), 926 deletions(-) rename onnxruntime/contrib_ops/{cpu => }/cpu_contrib_kernels.cc (99%) rename onnxruntime/contrib_ops/{cpu => }/cpu_contrib_kernels.h (100%) rename onnxruntime/contrib_ops/{cuda => }/cuda_contrib_kernels.cc (100%) rename onnxruntime/contrib_ops/{cuda => }/cuda_contrib_kernels.h (100%) rename onnxruntime/featurizers_ops/{cpu => }/cpu_featurizers_kernels.cc (99%) rename onnxruntime/featurizers_ops/{cpu => }/cpu_featurizers_kernels.h (100%) delete mode 100644 orttraining/orttraining/test/gradient/event_op_test.cc delete mode 100644 orttraining/orttraining/test/training_ops/cpu/tensor/gather_nd_grad_op_test.cc create mode 100644 orttraining/orttraining/test/training_ops/cpu/tensor/gather_nd_op_test.cc delete mode 100644 orttraining/orttraining/training_ops/cpu/controlflow/common.h create mode 100644 orttraining/orttraining/training_ops/cpu/tensor/gather_nd.cc create mode 100644 orttraining/orttraining/training_ops/cpu/tensor/gather_nd.h rename orttraining/orttraining/training_ops/{cpu => }/cpu_training_kernels.cc (98%) rename orttraining/orttraining/training_ops/{cpu => }/cpu_training_kernels.h (100%) delete mode 100644 orttraining/orttraining/training_ops/cuda/controlflow/record.cc delete mode 100644 orttraining/orttraining/training_ops/cuda/controlflow/record.h delete mode 100644 orttraining/orttraining/training_ops/cuda/controlflow/wait.cc delete mode 100644 
orttraining/orttraining/training_ops/cuda/controlflow/wait.h rename {onnxruntime/core/providers => orttraining/orttraining/training_ops}/cuda/tensor/gather_nd.cc (62%) rename {onnxruntime/core/providers => orttraining/orttraining/training_ops}/cuda/tensor/gather_nd.h (72%) delete mode 100644 orttraining/orttraining/training_ops/cuda/tensor/gather_nd_gard_impl.cu delete mode 100644 orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.cc delete mode 100644 orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.h rename {onnxruntime/core/providers => orttraining/orttraining/training_ops}/cuda/tensor/gather_nd_impl.cu (68%) rename {onnxruntime/core/providers => orttraining/orttraining/training_ops}/cuda/tensor/gather_nd_impl.h (91%) create mode 100644 orttraining/orttraining/training_ops/cuda/tensor/thrustallocator.h rename orttraining/orttraining/training_ops/{cuda => }/cuda_training_kernels.cc (97%) rename orttraining/orttraining/training_ops/{cuda => }/cuda_training_kernels.h (100%) diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 389146a348..c91daf3c4b 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -7,11 +7,15 @@ file(GLOB_RECURSE onnxruntime_providers_srcs CONFIGURE_DEPENDS ) file(GLOB_RECURSE onnxruntime_cpu_contrib_ops_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu_contrib_kernels.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu_contrib_kernels.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/*.h" "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/*.cc" ) file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda_contrib_kernels.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda_contrib_kernels.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.h" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cc" ) @@ -22,6 +26,8 @@ file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cu_srcs CONFIGURE_DEPENDS ) file(GLOB onnxruntime_cpu_featurizers_cc_srcs 
CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/featurizers_ops/cpu_featurizers_kernels.h" + "${ONNXRUNTIME_ROOT}/featurizers_ops/cpu_featurizers_kernels.cc" "${ONNXRUNTIME_ROOT}/featurizers_ops/cpu/*.h" "${ONNXRUNTIME_ROOT}/featurizers_ops/cpu/*.cc" ) @@ -89,6 +95,8 @@ endif() if (onnxruntime_ENABLE_TRAINING) file(GLOB_RECURSE onnxruntime_cpu_training_ops_srcs CONFIGURE_DEPENDS + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu_training_kernels.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu_training_kernels.cc" "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/*.h" "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/*.cc" ) @@ -166,6 +174,8 @@ if (onnxruntime_USE_CUDA) if (onnxruntime_ENABLE_TRAINING) file(GLOB_RECURSE onnxruntime_cuda_training_ops_cc_srcs CONFIGURE_DEPENDS + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda_training_kernels.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda_training_kernels.cc" "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.h" "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.cc" ) diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu_contrib_kernels.cc similarity index 99% rename from onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc rename to onnxruntime/contrib_ops/cpu_contrib_kernels.cc index 444904aa2c..e9c9d9f873 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu_contrib_kernels.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "contrib_ops/cpu/cpu_contrib_kernels.h" +#include "contrib_ops/cpu_contrib_kernels.h" #include "core/graph/constants.h" #include "core/mlas/inc/mlas.h" diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.h b/onnxruntime/contrib_ops/cpu_contrib_kernels.h similarity index 100% rename from onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.h rename to onnxruntime/contrib_ops/cpu_contrib_kernels.h diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda_contrib_kernels.cc similarity index 100% rename from onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc rename to onnxruntime/contrib_ops/cuda_contrib_kernels.cc diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.h b/onnxruntime/contrib_ops/cuda_contrib_kernels.h similarity index 100% rename from onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.h rename to onnxruntime/contrib_ops/cuda_contrib_kernels.h diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index aaa2258fe6..ecb9f5ccb6 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2077,6 +2077,77 @@ Output = Dequantize(Input) -> AveragePool on fp32 data -> Quantize(output) .SetDoc(R"DOC( Given `data` tensor of rank r >= 1, and `indices` tensor of rank q >= 1, gather slices of `data` into an output tensor of rank q - 1 + r - indices[-1]. +Example 1: + data = [[0,1],[2,3]] + indices = [[0,0],[1,1]] + output = [0,3] +Example 2: + data = [[0,1],[2,3]] + indices = [[1],[0]] + output = [[2,3],[0,1]] +Example 3: + data = [[[0,1],[2,3]],[[4,5],[6,7]]] + indices = [[0,1],[1,0]] + output = [[2,3],[4,5]] +Example 4: + data = [[[0,1],[2,3]],[[4,5],[6,7]]] + indices = [[[0,1]],[[1,0]]] + output = [[[2,3]],[[4,5]]] +)DOC"); + + ONNX_CONTRIB_OPERATOR_SCHEMA(GatherND) + .SetDomain(kOnnxDomain) + .SinceVersion(1) + .Attr( + "axis", + "The number of batch dims. 
The gather of indexing starts from dimension of data[axis:]", + AttributeProto::INT, + static_cast(0)) + .Input(0, "data", "Tensor of rank r >= 1.", "T") + .Input(1, "indices", "Tensor of rank q >= 1.", "Tind") + .Output(0, "output", "Tensor of rank q-1+r-indices[-1].", "T") + .TypeConstraint( + "T", + OpSchema::all_tensor_types(), + "Constrain input and output types to any tensor type.") + .TypeConstraint( + "Tind", + {"tensor(int32)", "tensor(int64)"}, + "Constrain indice type to int32 or int64") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (!hasNInputShapes(ctx, 2)) { + return; + } + auto& data_shape = ctx.getInputType(0)->tensor_type().shape(); + auto& indices_shape = ctx.getInputType(1)->tensor_type().shape(); + auto data_rank = data_shape.dim_size(); + auto indices_rank = indices_shape.dim_size(); + auto axis = ctx.getAttribute("axis"); + int64_t axis_data = axis ? static_cast(axis->i()) : 0; + if (data_rank < 1 || indices_rank < 1) { + fail_shape_inference("both data and indices tensor need to have rank larger than zero."); + } + auto last_indice_dimension = indices_shape.dim(indices_rank - 1).dim_value() + axis_data; + if (last_indice_dimension > data_rank) { + fail_shape_inference("last dimension of indices must not be larger and rank of data tensor"); + } + for (int i = 0; i < indices_rank - 1; ++i) { + *ctx.getOutputType(0) + ->mutable_tensor_type() + ->mutable_shape() + ->add_dim() = indices_shape.dim(i); + } + for (int i = static_cast(last_indice_dimension); i < data_rank; ++i) { + *ctx.getOutputType(0) + ->mutable_tensor_type() + ->mutable_shape() + ->add_dim() = data_shape.dim(i); + } + }) + .SetDoc(R"DOC( +Given `data` tensor of rank r >= 1, and `indices` tensor of rank q >= 1, gather +slices of `data` into an output tensor of rank q - 1 + r - indices[-1]. 
Example 1: data = [[0,1],[2,3]] indices = [[0,0],[1,1]] diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.cc b/onnxruntime/core/providers/acl/acl_execution_provider.cc index 56ff6a759c..8319121abf 100644 --- a/onnxruntime/core/providers/acl/acl_execution_provider.cc +++ b/onnxruntime/core/providers/acl/acl_execution_provider.cc @@ -7,7 +7,7 @@ #include "core/framework/op_kernel.h" #include "core/framework/kernel_registry.h" #include "core/framework/compute_capability.h" -#include "contrib_ops/cpu/cpu_contrib_kernels.h" +#include "contrib_ops/cpu_contrib_kernels.h" #include "acl_fwd.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index c28e966899..21c6b9caea 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -6,15 +6,15 @@ #include "core/framework/kernel_registry.h" #ifndef DISABLE_CONTRIB_OPS -#include "contrib_ops/cpu/cpu_contrib_kernels.h" +#include "contrib_ops/cpu_contrib_kernels.h" #endif #ifdef ML_FEATURIZERS -#include "featurizers_ops/cpu/cpu_featurizers_kernels.h" +#include "featurizers_ops/cpu_featurizers_kernels.h" #endif #ifdef ENABLE_TRAINING -#include "orttraining/training_ops/cpu/cpu_training_kernels.h" +#include "orttraining/training_ops/cpu_training_kernels.h" #endif #include "core/framework/compute_capability.h" @@ -435,7 +435,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, int8_t, ReduceMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, uint8_t, ReduceMin); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, GatherND); + Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -1082,8 
+1082,6 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { ReduceMin)>, BuildKernelCreateInfo, - - BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/tensor/gather_nd.cc b/onnxruntime/core/providers/cpu/tensor/gather_nd.cc index ae0e77e835..9868cd3178 100644 --- a/onnxruntime/core/providers/cpu/tensor/gather_nd.cc +++ b/onnxruntime/core/providers/cpu/tensor/gather_nd.cc @@ -25,22 +25,12 @@ ONNX_OPERATOR_KERNEL_EX(GatherND, kMSDomain, 1, kCpuExecutionProvider, #endif -ONNX_CPU_OPERATOR_KERNEL( - GatherND, - 11, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) - // official ONNX spec only supports `int64_t` for indices - .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), - GatherND); - -ONNX_CPU_OPERATOR_KERNEL( - GatherND, - 12, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) - .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), - GatherND); +ONNX_CPU_OPERATOR_KERNEL(GatherND, 11, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) + // official ONNX spec only supports `int64_t` for indices + .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), + GatherND); template Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const { @@ -54,7 +44,7 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "indices tensor must has rank larger than 0"); } - int64_t last_indices_dimension = indices_shape[indices_shape.NumDimensions() - 1] + batch_dims_; + int64_t last_indices_dimension = indices_shape[indices_shape.NumDimensions() - 1]; if (last_indices_dimension > static_cast(input_shape.NumDimensions())) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "last dimension of indices must not be larger than rank of input tensor"); @@ -63,7 +53,7 @@ Status 
GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con std::vector shape(indices_shape.GetDims().begin(), indices_shape.GetDims().end() - 1); shape.insert(shape.end(), input_shape.GetDims().begin() + last_indices_dimension, input_shape.GetDims().end()); auto* output_tensor = context->Output(0, TensorShape(std::move(shape))); - std::vector element_counts(last_indices_dimension + batch_dims_, + std::vector element_counts(last_indices_dimension, 0LL); // Number of elements for each input dimension #ifdef _OPENMP @@ -73,20 +63,12 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con element_counts[i] = input_shape.SizeFromDimension(i + 1); } - auto last_dim_size = indices_shape.SizeFromDimension(indices_shape.NumDimensions() - 1); -#ifdef USE_OPENMP -#pragma omp parallel for -#endif - for (int64_t i = batch_dims_ - 1; i >= 0; --i) { - element_counts[last_indices_dimension + i] = indices_shape.SizeFromDimension(i + 1) / last_dim_size; - } - int64_t err_index = 0; p.element_bytes = input_tensor->DataType()->Size(); p.element_to_copy = input_shape.SizeFromDimension(last_indices_dimension); p.bytes_to_copy = p.element_bytes * p.element_to_copy; - const auto* indice_offset = indices_tensor->Data(); - const int64_t offset_count = indices_shape.Size() / (last_indices_dimension - batch_dims_); // Times to copy + const auto* indices_data = indices_tensor->Data(); + const int64_t offset_count = indices_shape.Size() / last_indices_dimension; // Times to copy p.element_offsets.assign(offset_count, 0LL); if (input_tensor->IsDataTypeString()) { @@ -97,19 +79,12 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con p.output_base = static_cast(output_tensor->MutableDataRaw()); } - //Compute the element_offset #ifdef _OPENMP #pragma omp parallel for #endif for (int64_t i = 0; i < offset_count; ++i) { - int64_t reminder = i; - for (int64_t j = 0; j < batch_dims_; ++j) { - int64_t idx = reminder / 
element_counts[last_indices_dimension + j]; - p.element_offsets[i] += idx * element_counts[j]; - reminder -= (idx * element_counts[last_indices_dimension + j]); - } - for (int64_t j = batch_dims_; j < last_indices_dimension; ++j) { - auto index = *(indice_offset + i * (last_indices_dimension - batch_dims_) + (j - batch_dims_)); + for (int64_t j = 0; j < last_indices_dimension; ++j) { + auto index = *(indices_data + i * last_indices_dimension + j); auto upper_limit = input_shape[j]; auto lower_limit = -upper_limit; if (index < lower_limit || index >= upper_limit) { diff --git a/onnxruntime/core/providers/cpu/tensor/gather_nd.h b/onnxruntime/core/providers/cpu/tensor/gather_nd.h index 135706a245..a169c5ab78 100644 --- a/onnxruntime/core/providers/cpu/tensor/gather_nd.h +++ b/onnxruntime/core/providers/cpu/tensor/gather_nd.h @@ -33,14 +33,11 @@ class GatherNDBase { template Status PrepareForCompute(OpKernelContext* context, Prepare& p) const; - int64_t batch_dims_; }; // class GatherNDBase class GatherND final : public OpKernel, protected GatherNDBase { public: - explicit GatherND(const OpKernelInfo& info) : OpKernel(info) { - info.GetAttrOrDefault("batch_dims", &batch_dims_, static_cast(0)); - } + explicit GatherND(const OpKernelInfo& info) : OpKernel(info) {} Status Compute(OpKernelContext* context) const override; private: diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 1611b6c5b4..88c101d429 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -12,11 +12,11 @@ #include "core/providers/cuda/gpu_data_transfer.h" #ifndef DISABLE_CONTRIB_OPS -#include "contrib_ops/cuda/cuda_contrib_kernels.h" +#include "contrib_ops/cuda_contrib_kernels.h" #endif #ifdef ENABLE_TRAINING -#include "orttraining/training_ops/cuda/cuda_training_kernels.h" +#include "orttraining/training_ops/cuda_training_kernels.h" 
#endif using namespace onnxruntime::common; @@ -769,8 +769,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, int8_t, ReduceMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, uint8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, int64_t, GatherND); - static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, @@ -1283,8 +1281,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - - BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc index 08ff82cee0..8fae7ae8b0 100644 --- a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc @@ -36,10 +36,7 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int e CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, streams_[exec_queue_id])); } else if (src_device.Type() == OrtDevice::GPU) { // copying between GPU, this is non-blocking - // Copy only if the two addresses are different. 
- if (dst_data != src_data) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, streams_[kCudaStreamDefault])); - } + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, streams_[kCudaStreamDefault])); } else { // copy from other CPU memory to GPU, this is blocking CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice)); diff --git a/onnxruntime/featurizers_ops/cpu/cpu_featurizers_kernels.cc b/onnxruntime/featurizers_ops/cpu_featurizers_kernels.cc similarity index 99% rename from onnxruntime/featurizers_ops/cpu/cpu_featurizers_kernels.cc rename to onnxruntime/featurizers_ops/cpu_featurizers_kernels.cc index 4a7c10dd1f..854dcc9063 100644 --- a/onnxruntime/featurizers_ops/cpu/cpu_featurizers_kernels.cc +++ b/onnxruntime/featurizers_ops/cpu_featurizers_kernels.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "featurizers_ops/cpu/cpu_featurizers_kernels.h" +#include "featurizers_ops/cpu_featurizers_kernels.h" #include "core/graph/constants.h" #include "core/framework/data_types.h" diff --git a/onnxruntime/featurizers_ops/cpu/cpu_featurizers_kernels.h b/onnxruntime/featurizers_ops/cpu_featurizers_kernels.h similarity index 100% rename from onnxruntime/featurizers_ops/cpu/cpu_featurizers_kernels.h rename to onnxruntime/featurizers_ops/cpu_featurizers_kernels.h diff --git a/onnxruntime/test/common/tensor_op_test_utils.h b/onnxruntime/test/common/tensor_op_test_utils.h index 5259d9c5f2..1122e3cb56 100644 --- a/onnxruntime/test/common/tensor_op_test_utils.h +++ b/onnxruntime/test/common/tensor_op_test_utils.h @@ -55,20 +55,6 @@ inline std::vector FillZeros(const std::vector& dims) { return val; } -// Returns a vector of `count` values which start at `start` and change by increments of `step`. 
-template -inline std::vector ValueRange( - size_t count, T start = static_cast(0), T step = static_cast(1)) { - std::vector result; - result.reserve(count); - T curr = start; - for (size_t i = 0; i < count; ++i) { - result.emplace_back(curr); - curr += step; - } - return result; -} - inline std::pair MeanStdev(std::vector& v) { float sum = std::accumulate(v.begin(), v.end(), 0.0f); float mean = sum / v.size(); diff --git a/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc b/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc index 7b15687120..91164a608f 100644 --- a/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc @@ -3,8 +3,6 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" -#include "test/common/cuda_op_test_utils.h" -#include "test/common/tensor_op_test_utils.h" namespace onnxruntime { namespace test { @@ -20,22 +18,15 @@ static void RunTest(const std::vector& input_dims, const std::initializ test1.AddOutput("output", output_dims, output); test1.Run(); - // ONNX domain opset-12 - OpTester test2("GatherND", 12); +#ifndef DISABLE_CONTRIB_OPS + + // MSFT domain opset-1 (contrib op) + OpTester test2("GatherND", 1, kMSDomain); test2.AddInput("data", input_dims, input); test2.AddInput("indices", indices_dims, indices); test2.AddOutput("output", output_dims, output); test2.Run(); -#ifndef DISABLE_CONTRIB_OPS - - // MSFT domain opset-1 (contrib op) - OpTester test3("GatherND", 1, kMSDomain); - test3.AddInput("data", input_dims, input); - test3.AddInput("indices", indices_dims, indices); - test3.AddOutput("output", output_dims, output); - test3.Run(); - #endif } @@ -79,21 +70,11 @@ TEST(GatherNDOpTest, int64_t) { } TEST(GatherNDOpTest, float) { - if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; - RunTest({2, 2}, {0.0f, 0.1f, 0.2f, 0.3f}, {2, 1}, {1LL, 0LL}, {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f}); - - // with negative indices - RunTest({2, 2}, 
{0.0f, 0.1f, 0.2f, 0.3f}, {2, 1}, {-1LL, 0LL}, {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f}); } TEST(GatherNDOpTest, double) { - if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; - RunTest({2, 2}, {0.0, 0.1, 0.2, 0.3}, {2, 1}, {1LL, 0LL}, {2, 2}, {0.2, 0.3, 0.0, 0.1}); - - // with negative indices - RunTest({2, 2}, {0.0, 0.1, 0.2, 0.3}, {2, 1}, {-1LL, 0LL}, {2, 2}, {0.2, 0.3, 0.0, 0.1}); } TEST(GatherNDOpTest, int8_t) { @@ -133,116 +114,5 @@ TEST(GatherNDOpTest, ContribOpInt32Indices) { #endif -TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_0) { - OpTester test("GatherND", 12, kOnnxDomain); - test.AddAttribute("batch_dims", 0); - test.AddInput("data", {2, 3, 4}, ValueRange(24, 1.0f)); - test.AddInput("indices", {3, 2, 2}, {0LL, 1LL, 0LL, 2LL, 1LL, 0LL, 0LL, 0LL, 1LL, 1LL, 1LL, 2LL}); - test.AddOutput("output", {3, 2, 4}, {5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 1.0, 2.0, 3.0, 4.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_1) { - OpTester test("GatherND", 12, kOnnxDomain); - test.AddAttribute("batch_dims", 1); - test.AddInput("data", {2, 3, 4}, ValueRange(24, 1.0f)); - test.AddInput("indices", {2, 2, 2}, {0LL, 1LL, 0LL, 2LL, 1LL, 0LL, 0LL, 0LL}); - test.AddOutput("output", {2, 2}, {2.0, 3.0, 17.0, 13.0}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_2) { - OpTester test("GatherND", 12, kOnnxDomain); - test.AddAttribute("batch_dims", 1); - test.AddInput("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f)); - test.AddInput("indices", {2, 1}, {1LL, 0LL}); - test.AddOutput("output", {2, 2}, {0.2f, 0.3f, 0.4f, 0.5f}); - test.Run(); -} - -#ifdef USE_CUDA -#if __CUDA_ARCH__ >= 600 -TEST(GatherNDOpTest, GatherND_slice_double_batch_dims_3) { - OpTester test("GatherND", 12, kOnnxDomain); - test.AddAttribute("batch_dims", 1); - test.AddInput("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f)); - test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); - 
test.AddOutput("output", {2, 1, 2}, {0.2f, 0.3f, 0.4f, 0.5f}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_slice_double) { - OpTester test("GatherND", 12, kOnnxDomain); - test.AddInput("data", {2, 2}, {0.0f, 0.1f, 0.2f, 0.3f}); - test.AddInput("indices", {2, 1}, {1LL, 0LL}); - test.AddOutput("output", {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f}); - test.Run(); -} -#endif -#endif - -TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_4) { - OpTester test("GatherND", 12, kOnnxDomain); - test.AddAttribute("batch_dims", 1); - test.AddInput("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f)); - test.AddInput("indices", {2, 1, 2}, {1LL, 0LL, 0LL, 1LL}); - test.AddOutput("output", {2, 1}, {0.2f, 0.5f}); - test.Run(); -} - -#ifdef USE_CUDA - -TEST(GatherNDOpTest, GatherND_slice_double_batch_dims_3) { - if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; - - OpTester test("GatherND", 12, kOnnxDomain); - test.AddAttribute("batch_dims", 1); - test.AddInput("data", {2, 2, 2}, ValueRange(8, 0.0, 0.1)); - test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); - test.AddOutput("output", {2, 1, 2}, {0.2f, 0.3f, 0.4f, 0.5f}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_slice_half) { - if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; - - OpTester test("GatherND", 12, kOnnxDomain); - std::vector data_f({0.0f, 0.1f, 0.2f, 0.3f}); - std::vector outputs_f({0.2f, 0.3f, 0.0f, 0.1f}); - std::vector data(4); - std::vector outputs(4); - ConvertFloatToMLFloat16(data_f.data(), data.data(), 4); - ConvertFloatToMLFloat16(outputs_f.data(), outputs.data(), 4); - test.AddInput("data", {2, 2}, data); - test.AddInput("indices", {2, 1}, {1LL, 0LL}); - test.AddOutput("output", {2, 2}, outputs); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_batch_dims_of_2) { - OpTester test("GatherND", 12, kOnnxDomain); - test.AddAttribute("batch_dims", 2); - test.AddInput("data", {2, 2, 2, 2, 3}, ValueRange(48)); - test.AddInput( - "indices", {2, 2, 1, 2}, - { - 0, 0, // batch 0 - 1, 0, // batch 1 
- 1, 1, // batch 2 - 0, 1, // batch 3 - }); - test.AddOutput( - "output", {2, 2, 1, 3}, - { - 0, 1, 2, // batch 0 - 18, 19, 20, // batch 1 - 33, 34, 35, // batch 2 - 39, 40, 41, // batch 3 - }); - test.Run(); -} - -#endif - } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 0c19e809ca..609ff30304 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -438,16 +438,16 @@ IMPLEMENT_GRADIENT_BUILDER(GetConcatGradient) { IMPLEMENT_GRADIENT_BUILDER(GetGatherNDGradient) { auto attributes = SrcNodeAttributes(); - ORT_ENFORCE(attributes.at("batch_dims").has_i()); - auto batch_dims = attributes.at("batch_dims").i(); + ORT_ENFORCE(attributes.at("axis").has_i()); + auto axis = attributes.at("axis").i(); return std::vector{ NodeDef("Shape", {I(0)}, {IA("x_shape")}), - NodeDef(OpDef{"GatherNDGrad", kMSDomain, 1}, + NodeDef("GatherNDGrad", {IA("x_shape"), I(1), GO(0)}, {GI(0)}, - {MakeAttribute("batch_dims", batch_dims)})}; + {MakeAttribute("axis", axis)})}; }; IMPLEMENT_GRADIENT_BUILDER(GetReshapeGradient) { diff --git a/orttraining/orttraining/core/graph/gradient_schema_defs.cc b/orttraining/orttraining/core/graph/gradient_schema_defs.cc index 8657a27f63..9249dd8410 100644 --- a/orttraining/orttraining/core/graph/gradient_schema_defs.cc +++ b/orttraining/orttraining/core/graph/gradient_schema_defs.cc @@ -698,84 +698,12 @@ void RegisterGradientSchemas() { propagateShapeAndTypeFromFirstInput(ctx); }); - // TODO: Depreacate this schema when training support is udpated to opset-12 - ONNX_CONTRIB_OPERATOR_SCHEMA(GatherND) + ONNX_CONTRIB_OPERATOR_SCHEMA(GatherNDGrad) .SetDomain(kOnnxDomain) .SinceVersion(1) .Attr( - "batch_dims", - "The number of batch dims. 
The gather of indexing starts from dimension of data[batch_dims:]", - AttributeProto::INT, - static_cast(0)) - .Input(0, "data", "Tensor of rank r >= 1.", "T") - .Input(1, "indices", "Tensor of rank q >= 1.", "Tind") - .Output(0, "output", "Tensor of rank q-1+r-indices[-1].", "T") - .TypeConstraint( - "T", - OpSchema::all_tensor_types(), - "Constrain input and output types to any tensor type.") - .TypeConstraint( - "Tind", - {"tensor(int32)", "tensor(int64)"}, - "Constrain indice type to int32 or int64") - .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - propagateElemTypeFromInputToOutput(ctx, 0, 0); - if (!hasNInputShapes(ctx, 2)) { - return; - } - auto& data_shape = ctx.getInputType(0)->tensor_type().shape(); - auto& indices_shape = ctx.getInputType(1)->tensor_type().shape(); - auto data_rank = data_shape.dim_size(); - auto indices_rank = indices_shape.dim_size(); - auto batch_dims = ctx.getAttribute("batch_dims"); - int64_t batch_dims_data = batch_dims ? static_cast(batch_dims->i()) : 0; - if (data_rank < 1 || indices_rank < 1) { - fail_shape_inference("both data and indices tensor need to have rank larger than zero."); - } - auto last_indice_dimension = indices_shape.dim(indices_rank - 1).dim_value() + batch_dims_data; - if (last_indice_dimension > data_rank) { - fail_shape_inference("last dimension of indices must not be larger and rank of data tensor"); - } - for (int i = 0; i < indices_rank - 1; ++i) { - *ctx.getOutputType(0) - ->mutable_tensor_type() - ->mutable_shape() - ->add_dim() = indices_shape.dim(i); - } - for (int i = static_cast(last_indice_dimension); i < data_rank; ++i) { - *ctx.getOutputType(0) - ->mutable_tensor_type() - ->mutable_shape() - ->add_dim() = data_shape.dim(i); - } - }) - .SetDoc(R"DOC( -Given `data` tensor of rank r >= 1, and `indices` tensor of rank q >= 1, gather -slices of `data` into an output tensor of rank q - 1 + r - indices[-1]. 
-Example 1: - data = [[0,1],[2,3]] - indices = [[0,0],[1,1]] - output = [0,3] -Example 2: - data = [[0,1],[2,3]] - indices = [[1],[0]] - output = [[2,3],[0,1]] -Example 3: - data = [[[0,1],[2,3]],[[4,5],[6,7]]] - indices = [[0,1],[1,0]] - output = [[2,3],[4,5]] -Example 4: - data = [[[0,1],[2,3]],[[4,5],[6,7]]] - indices = [[[0,1]],[[1,0]]] - output = [[[2,3]],[[4,5]]] -)DOC"); - - ONNX_CONTRIB_OPERATOR_SCHEMA(GatherNDGrad) - .SetDomain(kMSDomain) - .SinceVersion(1) - .Attr( - "batch_dims", - "The number of batch dims. The gather of indexing starts from dimension of data[batch_dims+1:]", + "axis", + "The number of batch dims. The gather of indexing starts from dimension of data[axis+1:]", AttributeProto::INT, static_cast(0)) .Input(0, "shape", "The shape of source data input of GatherND.", "T1") @@ -788,7 +716,7 @@ Example 4: "Constrain input and output types to any tensor type.") .TypeConstraint( "Tind", - {"tensor(int64)"}, + {"tensor(int32)", "tensor(int64)"}, "Constrain indice type to int32 or int64") .TypeConstraint( "T1", diff --git a/orttraining/orttraining/core/graph/loss_func/bert_loss.cc b/orttraining/orttraining/core/graph/loss_func/bert_loss.cc index b3237c2405..a75f9faa5a 100644 --- a/orttraining/orttraining/core/graph/loss_func/bert_loss.cc +++ b/orttraining/orttraining/core/graph/loss_func/bert_loss.cc @@ -83,10 +83,10 @@ GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFun "Mask_LM_Positions_Unsqueezed")); TypeProto* gathered_prediction_type_proto = GetGatheredPredictionTypeProto(prediction_arg, graph_defs); - new_nodes.emplace_back(NodeDef(OpDef{"GatherND", kOnnxDomain, 12}, + new_nodes.emplace_back(NodeDef("GatherND", {ArgDef(prediction_masked_lm), ArgDef("masked_lm_positions_unsqueezed")}, {ArgDef("gathered_prediction", gathered_prediction_type_proto)}, - {ONNX_NAMESPACE::MakeAttribute("batch_dims", static_cast(1))}, + {ONNX_NAMESPACE::MakeAttribute("axis", static_cast(1))}, "GATHERED_LM")); TypeProto* 
masked_lm_float_type_proto = GetMaskedLMTypeProto(prediction_arg, diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 4628bad416..8e68e15d5c 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -94,7 +94,7 @@ TrainingConfigurationResult ConfigureSessionForTraining( config.weight_names_to_not_train = parameters.weights_not_to_train; config.immutable_weights = parameters.immutable_weights; - config.set_gradients_as_graph_outputs = false; + config.set_gradients_as_graph_outputs = true; config.gradient_accumulation_steps = parameters.gradient_accumulation_steps; @@ -115,7 +115,6 @@ TrainingConfigurationResult ConfigureSessionForTraining( config.loss_name = parameters.loss_output_name; if (!parameters.training_optimizer_name.empty()) { - config.set_gradients_as_graph_outputs = true; training::TrainingSession::TrainingConfiguration::OptimizerConfiguration opt{}; opt.name = parameters.training_optimizer_name; opt.learning_rate_input_name = parameters.lr_params_feed_name; @@ -277,4 +276,4 @@ void addObjectMethodsForTraining(py::module& m) { } } // namespace python -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/test/gradient/event_op_test.cc b/orttraining/orttraining/test/gradient/event_op_test.cc deleted file mode 100644 index 881ed4fa5d..0000000000 --- a/orttraining/orttraining/test/gradient/event_op_test.cc +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include -#include -#include -#include -#include - -#include "gtest/gtest.h" - -#include "test/common/tensor_op_test_utils.h" -#include "test/providers/provider_test_utils.h" -#include "test/util/include/test_random_seed.h" -#include "test/util/include/default_providers.h" - -#include "onnx/defs/attr_proto_util.h" - -namespace onnxruntime { -namespace test { - -// Run GPU op for GPU build. Otherwise, run GPU op. -void run_provider_specific_optest(OpTester& tester) { - RunOptions run_option; -#ifdef USE_CUDA - std::vector> providers; - providers.push_back(DefaultCudaExecutionProvider()); -#else - std::vector> providers; - providers.push_back(DefaultCpuExecutionProvider()); -#endif - tester.Run( - OpTester::ExpectResult::kExpectSuccess, - "", - std::unordered_set(), - &run_option, - &providers); -} - -void record_event(int64_t event_id) { - OpTester test_record("RecordEvent", 1, onnxruntime::kMSDomain); - test_record.AddInput("EventIdentifier", {}, {event_id}); - test_record.AddInput("InputSignal", {}, {true}); - test_record.AddOutput("OutputSignal", {}, {true}); - run_provider_specific_optest(test_record); -} - -void record_event_multiple_inputs_and_outputs(int64_t event_id) { - OpTester test_record("RecordEvent", 1, onnxruntime::kMSDomain); - test_record.AddInput("EventIdentifier", {}, {event_id}); - test_record.AddInput("InputSignal", {}, {true}); - test_record.AddInput("Input1", {3}, {9.4f, 1.7f, 3.6f}); - test_record.AddInput("Input2", {1}, {1.6f}); - test_record.AddOutput("OutputSignal", {}, {true}); - test_record.AddOutput("Output1", {3}, {9.4f, 1.7f, 3.6f}); - test_record.AddOutput("Output2", {1}, {1.6f}); - run_provider_specific_optest(test_record); -} - -void wait_event(int64_t event_id) { - OpTester test_wait("WaitEvent", 1, onnxruntime::kMSDomain); - test_wait.AddInput("EventIdentifier", {}, {event_id}); - test_wait.AddInput("InputSignal", {}, {true}); - test_wait.AddOutput("OutputSignal", {}, {true}); - run_provider_specific_optest(test_wait); -} - 
-void wait_event_multiple_inputs_and_outputs(int64_t event_id) { - OpTester test_wait("WaitEvent", 1, onnxruntime::kMSDomain); - test_wait.AddInput("EventIdentifier", {}, {event_id}); - test_wait.AddInput("InputSignal", {}, {true}); - test_wait.AddInput("Input1", {1}, {1.6f}); - test_wait.AddInput("Input2", {3}, {9.4f, 1.7f, 3.6f}); - test_wait.AddOutput("OutputSignal", {}, {true}); - test_wait.AddOutput("output1", {1}, {1.6f}); - test_wait.AddOutput("output2", {3}, {9.4f, 1.7f, 3.6f}); - run_provider_specific_optest(test_wait); -} - -TEST(Synchronization, RecordAndWaitEvent) { - const int64_t event_id = static_cast(1736); - record_event(event_id); - wait_event(event_id); -} - -TEST(Synchronization, WaitNullEvent) { - wait_event(-1); -} - -TEST(Synchronization, RecordAndWaitEventMultipleInputsAndOutputs) { - const int64_t event_id = static_cast(995); - record_event_multiple_inputs_and_outputs(event_id); - wait_event_multiple_inputs_and_outputs(event_id); -} - -TEST(Synchronization, WaitAndRecordEvent) { - const int64_t event_id = static_cast(1228); - std::thread waiting_thread(wait_event, event_id); - std::this_thread::sleep_for(std::chrono::milliseconds(5)); - std::thread recording_thread(record_event, event_id); - - waiting_thread.join(); - recording_thread.join(); -} - -TEST(Synchronization, WaitAndRecordEventMany) { - const size_t event_count = 16; - for (int i = 0; i < 8; ++i) { - std::thread thread_pool[2 * event_count]; - for (int j = 0; j < static_cast(event_count); ++j) { - thread_pool[j] = std::thread(wait_event, j); - thread_pool[j + event_count] = std::thread(record_event, j); - } - for (size_t j = 0; j < event_count; ++j) { - thread_pool[j].join(); - thread_pool[j + event_count].join(); - } - } -} - -} // namespace test -} // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/test/gradient/gradient_checker.cc b/orttraining/orttraining/test/gradient/gradient_checker.cc index 4fc263a461..c12fae5cd1 100644 --- 
a/orttraining/orttraining/test/gradient/gradient_checker.cc +++ b/orttraining/orttraining/test/gradient/gradient_checker.cc @@ -457,8 +457,8 @@ inline Status GradientChecker::ComputeGradientErrorInternal( // TODO: These 4 test failed at following ORT_ENFORCE. need investigate before enable it. //GradientCheckerTest.MatMulGrad //GradientCheckerTest.GemmGrad - //GradientCheckerTest.GatherNDGrad_repeat_float_data - //GradientCheckerTest.GatherNDGrad_unique_float_data + //GradientCheckerTest.GatherNDGrad_int64_indice_repeat_float_data + //GradientCheckerTest.GatherNDGrad_int64_indice_unique_float_data //auto jac_t = jacobian_ts[j]; //ORT_ENFORCE(std::all_of( // &jac_t[0], &jac_t[0] + x_info.shape.Size(), [](auto dx) { return dx == 0; })); diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index a407d07e7d..556160bd58 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -1535,62 +1535,68 @@ TEST(GradientCheckerTest, DISABLED_DropoutGrad) { } } -TEST(GradientCheckerTest, GatherNDGrad_repeat_float_data) { +TEST(GradientCheckerTest, GatherNDGrad_int64_indice_repeat_float_data) { float max_error; GradientChecker gradient_checker; - OpDef op_def{"GatherND", kOnnxDomain, 12}; + OpDef op_def{"GatherND"}; TensorInfo x_info({2, 2}, true); TensorInfo indice_info({2, 2}, false, nullptr, DataTypeImpl::GetTensorType()); std::vector> x_datas = {{0, 1, 2, 3}, {1, 1, 1, 1}}; TensorInfo y_info({2}, true); - int64_t batch_dims = 0; + int64_t axis = 0; - gradient_checker.ComputeGradientError(op_def, {x_info, indice_info}, {y_info}, &max_error, x_datas, {MakeAttribute("batch_dims", batch_dims)}); + gradient_checker.ComputeGradientError(op_def, {x_info, indice_info}, {y_info}, &max_error, x_datas, {MakeAttribute("axis", axis)}); EXPECT_IS_TINY(max_error); } -TEST(GradientCheckerTest, GatherNDGrad_unique_float_data) { 
+TEST(GradientCheckerTest, GatherNDGrad_int64_indice_unique_float_data) { float max_error; GradientChecker gradient_checker; - OpDef op_def{"GatherND", kOnnxDomain, 12}; + OpDef op_def{"GatherND"}; - { - TensorInfo x_info({2, 2}, true); - TensorInfo indice_info({2, 2}, false, nullptr, DataTypeImpl::GetTensorType()); - std::vector> x_datas = {{0, 1, 2, 3}, {0, 1, 1, 0}}; + TensorInfo x_info({2, 2}, true); + TensorInfo indice_info({2, 2}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{0, 1, 2, 3}, {0, 1, 1, 0}}; - TensorInfo y_info({2}, true); - int64_t batch_dims = 0; + TensorInfo y_info({2}, true); + int64_t axis = 0; - gradient_checker.ComputeGradientError(op_def, {x_info, indice_info}, {y_info}, &max_error, x_datas, {MakeAttribute("batch_dims", batch_dims)}); - EXPECT_IS_TINY(max_error); - } + gradient_checker.ComputeGradientError(op_def, {x_info, indice_info}, {y_info}, &max_error, x_datas, {MakeAttribute("axis", axis)}); + EXPECT_IS_TINY(max_error); +} - { - TensorInfo x_info({2, 2, 3}, true); - TensorInfo indice_info({2, 1}, false, nullptr, DataTypeImpl::GetTensorType()); - std::vector> x_datas = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1, 0}}; +TEST(GradientCheckerTest, GatherNDGrad_int32_indice_unique_float_data) { + float max_error; + GradientChecker gradient_checker; + OpDef op_def{"GatherND"}; - TensorInfo y_info({2, 3}, true); - int64_t batch_dims = 1; + TensorInfo x_info({2, 2, 3}, true); + TensorInfo indice_info({2, 1}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1, 0}}; - gradient_checker.ComputeGradientError(op_def, {x_info, indice_info}, {y_info}, &max_error, x_datas, {MakeAttribute("batch_dims", batch_dims)}); - EXPECT_IS_TINY(max_error); - } + TensorInfo y_info({2, 3}, true); + int64_t axis = 1; - { - TensorInfo x_info({2, 2, 3}, true); - TensorInfo indice_info({2, 2, 1}, false, nullptr, DataTypeImpl::GetTensorType()); - std::vector> x_datas = {{0, 
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1, 0, 2, 1}}; + gradient_checker.ComputeGradientError(op_def, {x_info, indice_info}, {y_info}, &max_error, x_datas, {MakeAttribute("axis", axis)}); + EXPECT_IS_TINY(max_error); +} - TensorInfo y_info({2, 2}, true); - int64_t batch_dims = 2; +TEST(GradientCheckerTest, GatherNDGrad_int32_indice_unique_float_data_axis_2) { + float max_error; + GradientChecker gradient_checker; + OpDef op_def{"GatherND"}; - gradient_checker.ComputeGradientError(op_def, {x_info, indice_info}, {y_info}, &max_error, x_datas, {MakeAttribute("batch_dims", batch_dims)}); - EXPECT_IS_TINY(max_error); - } + TensorInfo x_info({2, 2, 3}, true); + TensorInfo indice_info({2, 2, 1}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1, 0, 2, 1}}; + + TensorInfo y_info({2, 2}, true); + int64_t axis = 2; + + gradient_checker.ComputeGradientError(op_def, {x_info, indice_info}, {y_info}, &max_error, x_datas, {MakeAttribute("axis", axis)}); + EXPECT_IS_TINY(max_error); } TEST(GradientCheckerTest, GatherElementsGradWithDuplicateUpdate) { @@ -1798,6 +1804,53 @@ TEST(GradientCheckerTest, SliceGrad) { } } +void record_event(int64_t event_id) { + OpTester test_record("RecordEvent", 1, onnxruntime::kMSDomain); + test_record.AddInput("EventIdentifier", {}, {event_id}); + test_record.AddInput("InputSignal", {}, {true}); + test_record.AddOutput("OutputSignal", {}, {true}); + test_record.Run(); +} + +void wait_event(int64_t event_id) { + OpTester test_wait("WaitEvent", 1, onnxruntime::kMSDomain); + test_wait.AddInput("EventIdentifier", {}, {event_id}); + test_wait.AddInput("InputSignal", {}, {true}); + test_wait.AddOutput("OutputSignal", {}, {true}); + test_wait.Run(); +} + +TEST(Synchronization, RecordAndWaitEvent) { + const int64_t event_id = static_cast(1736); + record_event(event_id); + wait_event(event_id); +} + +TEST(Synchronization, WaitAndRecordEvent) { + const int64_t event_id = static_cast(1228); + 
std::thread waiting_thread(wait_event, event_id); + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + std::thread recording_thread(record_event, event_id); + + waiting_thread.join(); + recording_thread.join(); +} + +TEST(Synchronization, WaitAndRecordEventMany) { + const size_t event_count = 16; + for (int i = 0; i < 8; ++i) { + std::thread thread_pool[2 * event_count]; + for (int j = 0; j < static_cast(event_count); ++j) { + thread_pool[j] = std::thread(wait_event, j); + thread_pool[j + event_count] = std::thread(record_event, j); + } + for (size_t j = 0; j < event_count; ++j) { + thread_pool[j].join(); + thread_pool[j + event_count].join(); + } + } +} + TEST(GradientCheckerTest, ExpandGrad) { float max_error; GradientChecker gradient_checker; diff --git a/orttraining/orttraining/test/training_ops/cpu/tensor/gather_nd_grad_op_test.cc b/orttraining/orttraining/test/training_ops/cpu/tensor/gather_nd_grad_op_test.cc deleted file mode 100644 index 02c32a0827..0000000000 --- a/orttraining/orttraining/test/training_ops/cpu/tensor/gather_nd_grad_op_test.cc +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "gtest/gtest.h" -#include "test/providers/provider_test_utils.h" -#include "test/common/cuda_op_test_utils.h" -#include "test/common/tensor_op_test_utils.h" - -namespace onnxruntime { -namespace test { - -#ifdef USE_CUDA -TEST(GatherNDGradOpTest, GatherNDGrad_slice_float_int64_t_batch_dims_1) { - OpTester test("GatherNDGrad", 1, kMSDomain); - test.AddAttribute("batch_dims", 0); - test.AddInput("shape", {3}, {2LL, 2LL, 3LL}); - test.AddInput("indices", {2, 2}, {0LL, 1LL, 1LL, 0LL}); - test.AddInput("update", {2, 3}, ValueRange(6, 1.0f)); - test.AddOutput("output", {2, 2, 3}, {0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0}); - test.Run(); -} - -TEST(GatherNDGradOpTest, GatherNDGrad_slice_double_int32_t_batch_dims_3) { - if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; - - OpTester test("GatherNDGrad", 1, kMSDomain); - test.AddAttribute("batch_dims", 1); - test.AddInput("shape", {3}, {2LL, 2LL, 3LL}); - test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); - test.AddInput("update", {2, 3}, ValueRange(6, 1.0)); - test.AddOutput("output", {2, 2, 3}, {0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0}); - test.Run(); -} - - -TEST(GatherNDGradOpTest, GatherNDGrad_slice_half_int32_t_batch_dims_3) { - if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; - - OpTester test("GatherNDGrad", 1, kMSDomain); - test.AddAttribute("batch_dims", 1); - test.AddInput("shape", {3}, {2LL, 2LL, 3LL}); - test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); - std::vector updates_f = ValueRange(6, 1.0f); - std::vector outputs_f({0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0}); - std::vector updates(6); - std::vector outputs(12); - ConvertFloatToMLFloat16(updates_f.data(), updates.data(), 6); - ConvertFloatToMLFloat16(outputs_f.data(), outputs.data(), 12); - test.AddInput("update", {2, 3}, updates); - test.AddOutput("output", {2, 2, 3}, outputs); - test.Run(); -} - -TEST(GatherNDGradOpTest, GatherNDGrad_batch_dims_of_2) { - OpTester test("GatherNDGrad", 1, kMSDomain); - test.AddAttribute("batch_dims", 
2); - test.AddInput("shape", {4}, {2, 2, 2, 3}); - test.AddInput( - "indices", {2, 2, 1}, - { - 1, // batch 0 - 1, // batch 1 - 0, // batch 2 - 1, // batch 3 - }); - test.AddInput("update", {2, 2, 3}, ValueRange(12)); - test.AddOutput( - "output", {2, 2, 2, 3}, - { - 0, 0, 0, 0, 1, 2, // batch 0 - 0, 0, 0, 3, 4, 5, // batch 1 - 6, 7, 8, 0, 0, 0, // batch 2 - 0, 0, 0, 9, 10, 11, // batch 3 - }); - test.Run(); -} -#endif - -} // namespace test -} // namespace onnxruntime diff --git a/orttraining/orttraining/test/training_ops/cpu/tensor/gather_nd_op_test.cc b/orttraining/orttraining/test/training_ops/cpu/tensor/gather_nd_op_test.cc new file mode 100644 index 0000000000..0d01d27299 --- /dev/null +++ b/orttraining/orttraining/test/training_ops/cpu/tensor/gather_nd_op_test.cc @@ -0,0 +1,343 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "test/common/cuda_op_test_utils.h" + +namespace onnxruntime { +namespace test { + +namespace { +// Returns a vector of `count` values which start at `start` and change by increments of `step`. 
+template +std::vector ValueRange( + size_t count, T start = static_cast(0), T step = static_cast(1)) { + std::vector result; + result.reserve(count); + T curr = start; + for (size_t i = 0; i < count; ++i) { + result.emplace_back(curr); + curr += step; + } + return result; +} +} // namespace + +TEST(GatherNDOpTest, GatherND_scalar_string_int32) { + OpTester test1("GatherND", 1, onnxruntime::kOnnxDomain); + test1.AddInput("data", {2, 2}, {"h", "k", "o", "z"}); + test1.AddInput("indices", {2}, {0, 1}); + test1.AddOutput("output", {}, {"k"}); + test1.Run(); + + OpTester test2("GatherND", 1, onnxruntime::kOnnxDomain); + test2.AddInput("data", {6}, {"h", "k", "o", "z", "l", "t"}); + test2.AddInput("indices", {1}, {3}); + test2.AddOutput("output", {}, {"z"}); + test2.Run(); + + OpTester test3("GatherND", 1, onnxruntime::kOnnxDomain); + test3.AddInput("data", {3, 2}, {"h", "k", "o", "z", "l", "t"}); + test3.AddInput("indices", {2}, {2, 1}); + test3.AddOutput("output", {}, {"t"}); + test3.Run(); +} + +TEST(GatherNDOpTest, GatherND_matrix_int64_int64) { + OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); + test.AddInput("data", {2, 2}, {0LL, 1LL, 2LL, 3LL}); + test.AddInput("indices", {2, 2}, {0LL, 0LL, 1LL, 1LL}); + test.AddOutput("output", {2}, {0LL, 3LL}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_matrix_string_int64) { + OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); + test.AddInput("data", {2, 2}, {"a", "b", "c", "d"}); + test.AddInput("indices", {2, 2}, {0LL, 0LL, 1LL, 1LL}); + test.AddOutput("output", {2}, {"a", "d"}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_matrix_int64_int32) { + OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); + test.AddInput("data", {2, 2}, {0LL, 1LL, 2LL, 3LL}); + test.AddInput("indices", {2, 2}, {0, 0, 1, 1}); + test.AddOutput("output", {2}, {0LL, 3LL}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_matrix_string_int32) { + OpTester test1("GatherND", 1, onnxruntime::kOnnxDomain); + 
test1.AddInput("data", {2, 2, 2}, {"egg", "dance", "air", "bob", "terry", "smart", "laugh", "kite"}); + test1.AddInput("indices", {2, 1, 2}, {0, 1, 1, 0}); + test1.AddOutput("output", {2, 1, 2}, {"air", "bob", "terry", "smart"}); + test1.Run(); + + OpTester test2("GatherND", 1, onnxruntime::kOnnxDomain); + test2.AddInput("data", {3, 3}, {"egg", "dance", "air", "bob", "terry", "smart", "laugh", "kite", "hop"}); + test2.AddInput("indices", {3, 2}, {2, 1, 1, 0, 0, 1}); + test2.AddOutput("output", {3}, {"kite", "bob", "dance"}); + test2.Run(); +} + +TEST(GatherNDOpTest, GatherND_slice_float_int64_t) { + OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); + test.AddInput("data", {2, 2}, {0.0f, 0.1f, 0.2f, 0.3f}); + test.AddInput("indices", {2, 1}, {1LL, 0LL}); + test.AddOutput("output", {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_slice_float_int64_t_axis_0) { + OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); + test.AddAttribute("axis", 0); + test.AddInput("data", {2, 3, 4}, ValueRange(24, 1.0f)); + test.AddInput("indices", {3, 2, 2}, {0LL, 1LL, 0LL, 2LL, 1LL, 0LL, 0LL, 0LL, 1LL, 1LL, 1LL, 2LL}); + test.AddOutput("output", {3, 2, 4}, {5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 1.0, 2.0, 3.0, 4.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_slice_float_int64_t_axis_1) { + OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); + test.AddAttribute("axis", 1); + test.AddInput("data", {2, 3, 4}, ValueRange(24, 1.0f)); + test.AddInput("indices", {2, 2, 2}, {0LL, 1LL, 0LL, 2LL, 1LL, 0LL, 0LL, 0LL}); + test.AddOutput("output", {2, 2}, {2.0, 3.0, 17.0, 13.0}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_slice_float_int32_t_axis_2) { + OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); + test.AddAttribute("axis", 1); + test.AddInput("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f)); + test.AddInput("indices", {2, 1}, {1LL, 0LL}); + 
test.AddOutput("output", {2, 2}, {0.2f, 0.3f, 0.4f, 0.5f}); + test.Run(); +} + +#ifdef USE_CUDA +#if __CUDA_ARCH__ >= 600 +TEST(GatherNDOpTest, GatherND_slice_double_int64_t_axis_3) { + OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); + test.AddAttribute("axis", 1); + test.AddInput("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f)); + test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); + test.AddOutput("output", {2, 1, 2}, {0.2f, 0.3f, 0.4f, 0.5f}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_slice_double_int32_t) { + OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); + test.AddInput("data", {2, 2}, {0.0f, 0.1f, 0.2f, 0.3f}); + test.AddInput("indices", {2, 1}, {1LL, 0LL}); + test.AddOutput("output", {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f}); + test.Run(); +} +#endif +#endif + +TEST(GatherNDOpTest, GatherND_slice_float_int64_t_axis_4) { + OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); + test.AddAttribute("axis", 1); + test.AddInput("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f)); + test.AddInput("indices", {2, 1, 2}, {1LL, 0LL, 0LL, 1LL}); + test.AddOutput("output", {2, 1}, {0.2f, 0.5f}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_3tensor_int64) { + OpTester test1("GatherND", 1, onnxruntime::kOnnxDomain); + test1.AddInput("data", {2, 2, 2}, ValueRange(8)); + test1.AddInput("indices", {2, 2}, {0LL, 1LL, 1LL, 0LL}); + test1.AddOutput("output", {2, 2}, {2LL, 3LL, 4LL, 5LL}); + test1.Run(); + + OpTester test2("GatherND", 1, onnxruntime::kOnnxDomain); + test2.AddInput("data", {2, 2, 2}, ValueRange(8)); + test2.AddInput("indices", {2, 3}, {0, 0, 1, 1, 0, 1}); + test2.AddOutput("output", {2}, {1, 5}); + test2.Run(); + + OpTester test3("GatherND", 1, onnxruntime::kOnnxDomain); + test3.AddInput("data", {2, 2, 2}, ValueRange(8)); + test3.AddInput("indices", {1, 1}, {1LL}); + test3.AddOutput("output", {1, 2, 2}, {4, 5, 6, 7}); + test3.Run(); +} + +TEST(GatherNDOpTest, GatherND_batched_index_int64) { + OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); + 
test.AddInput("data", {2, 2}, {0LL, 1LL, 2LL, 3LL}); + test.AddInput("indices", {2, 1, 2}, {0LL, 0LL, 0LL, 1LL}); + test.AddOutput("output", {2, 1}, {0LL, 1LL}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_batched_index_bool_int64) { + OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); + test.AddInput("data", {2, 2}, {true, false, false, true}); + test.AddInput("indices", {2, 1, 2}, {0LL, 0LL, 0LL, 1LL}); + test.AddOutput("output", {2, 1}, {true, false}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_sliced_index_int64) { + OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); + test.AddInput("data", {2, 2}, {0LL, 1LL, 2LL, 3LL}); + test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); + test.AddOutput("output", {2, 1, 2}, {2LL, 3LL, 0LL, 1LL}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_sliced_index_string_int32) { + OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); + test.AddInput("data", {2, 2}, {"ab", "cde", "f", "ghi"}); + test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); + test.AddOutput("output", {2, 1, 2}, {"f", "ghi", "ab", "cde"}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_batched_3tensor_int64) { + OpTester test1("GatherND", 1, onnxruntime::kOnnxDomain); + test1.AddInput("data", {2, 2, 2}, ValueRange(8)); + test1.AddInput("indices", {2, 2, 2}, {0LL, 1LL, 1LL, 0LL, 0LL, 0LL, 1LL, 1LL}); + test1.AddOutput("output", {2, 2, 2}, {2, 3, 4, 5, 0, 1, 6, 7}); + test1.Run(); + + OpTester test2("GatherND", 1, onnxruntime::kOnnxDomain); + test2.AddInput("data", {2, 2, 2}, ValueRange(8)); + test2.AddInput("indices", {2, 2, 3}, {0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0}); + test2.AddOutput("output", {2, 2}, {1, 5, 3, 6}); + test2.Run(); + + OpTester test3("GatherND", 1, onnxruntime::kOnnxDomain); + test3.AddInput("data", {2, 2, 2}, ValueRange(8)); + test3.AddInput("indices", {2, 1, 1}, {1, 0}); + test3.AddOutput("output", {2, 1, 2, 2}, {4LL, 5LL, 6LL, 7LL, 0LL, 1LL, 2LL, 3LL}); + test3.Run(); +} + +#ifdef USE_CUDA +TEST(GatherNDOpTest, 
GatherNDGrad_slice_float_int64_t_axis_1) { + OpTester test("GatherNDGrad", 1, onnxruntime::kOnnxDomain); + test.AddAttribute("axis", 0); + test.AddInput("shape", {3}, {2LL, 2LL, 3LL}); + test.AddInput("indices", {2, 2}, {0LL, 1LL, 1LL, 0LL}); + test.AddInput("update", {2, 3}, ValueRange(6, 1.0f)); + test.AddOutput("output", {2, 2, 3}, {0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0}); + test.Run(); +} +#endif + +#ifdef USE_CUDA +TEST(GatherNDOpTest, GatherNDGrad_slice_double_int32_t_axis_3) { + if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; + + OpTester test("GatherNDGrad", 1, onnxruntime::kOnnxDomain); + test.AddAttribute("axis", 1); + test.AddInput("shape", {3}, {2LL, 2LL, 3LL}); + test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); + test.AddInput("update", {2, 3}, ValueRange(6, 1.0)); + test.AddOutput("output", {2, 2, 3}, {0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_slice_double_int64_t_axis_3) { + if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; + + OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); + test.AddAttribute("axis", 1); + test.AddInput("data", {2, 2, 2}, ValueRange(8, 0.0, 0.1)); + test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); + test.AddOutput("output", {2, 1, 2}, {0.2f, 0.3f, 0.4f, 0.5f}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherNDGrad_slice_half_int32_t_axis_3) { + if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; + + OpTester test("GatherNDGrad", 1, onnxruntime::kOnnxDomain); + test.AddAttribute("axis", 1); + test.AddInput("shape", {3}, {2LL, 2LL, 3LL}); + test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); + std::vector updates_f = ValueRange(6, 1.0f); + std::vector outputs_f({0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0}); + std::vector updates(6); + std::vector outputs(12); + ConvertFloatToMLFloat16(updates_f.data(), updates.data(), 6); + ConvertFloatToMLFloat16(outputs_f.data(), outputs.data(), 12); + test.AddInput("update", {2, 3}, updates); + 
test.AddOutput("output", {2, 2, 3}, outputs); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_slice_half_int32_t) { + if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; + + OpTester test("GatherND", 1, onnxruntime::kOnnxDomain); + std::vector data_f({0.0f, 0.1f, 0.2f, 0.3f}); + std::vector outputs_f({0.2f, 0.3f, 0.0f, 0.1f}); + std::vector data(4); + std::vector outputs(4); + ConvertFloatToMLFloat16(data_f.data(), data.data(), 4); + ConvertFloatToMLFloat16(outputs_f.data(), outputs.data(), 4); + test.AddInput("data", {2, 2}, data); + test.AddInput("indices", {2, 1}, {1LL, 0LL}); + test.AddOutput("output", {2, 2}, outputs); + test.Run(); +} +#endif + +#ifdef USE_CUDA +TEST(GatherNDOpTest, GatherND_axis_of_2) { + OpTester test("GatherND", 1, kOnnxDomain); + test.AddAttribute("axis", 2); + test.AddInput("data", {2, 2, 2, 2, 3}, ValueRange(48)); + test.AddInput( + "indices", {2, 2, 1, 2}, + { + 0, 0, // batch 0 + 1, 0, // batch 1 + 1, 1, // batch 2 + 0, 1, // batch 3 + }); + test.AddOutput( + "output", {2, 2, 1, 3}, + { + 0, 1, 2, // batch 0 + 18, 19, 20, // batch 1 + 33, 34, 35, // batch 2 + 39, 40, 41, // batch 3 + }); + test.Run(); +} + +TEST(GatherNDOpTest, GatherNDGrad_axis_of_2) { + OpTester test("GatherNDGrad", 1, kOnnxDomain); + test.AddAttribute("axis", 2); + test.AddInput("shape", {4}, {2, 2, 2, 3}); + test.AddInput( + "indices", {2, 2, 1}, + { + 1, // batch 0 + 1, // batch 1 + 0, // batch 2 + 1, // batch 3 + }); + test.AddInput("update", {2, 2, 3}, ValueRange(12)); + test.AddOutput( + "output", {2, 2, 2, 3}, + { + 0, 0, 0, 0, 1, 2, // batch 0 + 0, 0, 0, 3, 4, 5, // batch 1 + 6, 7, 8, 0, 0, 0, // batch 2 + 0, 0, 0, 9, 10, 11, // batch 3 + }); + test.Run(); +} +#endif + +} // namespace test +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/common.h b/orttraining/orttraining/training_ops/cpu/controlflow/common.h deleted file mode 100644 index c6ff14689f..0000000000 --- 
a/orttraining/orttraining/training_ops/cpu/controlflow/common.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -namespace onnxruntime { -namespace contrib { - -template -std::vector> AliasRange(int start, int end) { - std::vector> aliases; - for (int i = start; i < end; i++) { - aliases.push_back(std::pair(input_start + i, output_start + i)); - } - return aliases; -} - -} // namespace contrib -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.cc b/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.cc index e9bb0c51c3..a39ab399c2 100644 --- a/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.cc +++ b/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.cc @@ -6,16 +6,8 @@ namespace onnxruntime { namespace contrib { -void OrtEventPool::CheckRange(const int64_t id) const { - ORT_ENFORCE( - id >= 0 && id < MaxNumItems, - "Got id ", id, - ". 
It should be in a range from 0 to ", - MaxNumItems, "."); -} - void OrtEventPool::SignalEvent(int64_t id) { - CheckRange(id); + ORT_ENFORCE(id >= 0 && id < MaxNumItems); std::unique_lock lock(pool_[id].mutex); pool_[id].signaled.store(true); lock.unlock(); @@ -23,18 +15,18 @@ void OrtEventPool::SignalEvent(int64_t id) { }; void OrtEventPool::ResetEvent(int64_t id) { - CheckRange(id); + ORT_ENFORCE(id >= 0 && id < MaxNumItems); std::lock_guard guard(pool_[id].mutex); pool_[id].signaled.store(false); }; bool OrtEventPool::QueryEvent(int64_t id) const { - CheckRange(id); + ORT_ENFORCE(id >= 0 && id < MaxNumItems); return pool_[id].signaled.load(); } void OrtEventPool::WaitEvent(int64_t id) const { - CheckRange(id); + ORT_ENFORCE(id >= 0 && id < MaxNumItems); std::unique_lock lock(pool_[id].mutex); pool_[id].cv.wait(lock, [this, id] { return pool_[id].signaled.load(); }); }; diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.h b/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.h index 68e9b95abb..511a5da234 100644 --- a/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.h +++ b/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.h @@ -34,8 +34,6 @@ class OrtEventPool final { OrtEventPool(const OrtEventPool&) = delete; OrtEventPool& operator=(const OrtEventPool&) = delete; - void CheckRange(const int64_t event_id) const; - struct Item { std::atomic signaled; mutable std::mutex mutex; @@ -45,11 +43,9 @@ class OrtEventPool final { signaled.store(false); } }; - enum { MaxNumItems = 4096 }; - Item pool_[MaxNumItems]; }; diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/record.cc b/orttraining/orttraining/training_ops/cpu/controlflow/record.cc index eb305a7c27..c36a23e9e7 100644 --- a/orttraining/orttraining/training_ops/cpu/controlflow/record.cc +++ b/orttraining/orttraining/training_ops/cpu/controlflow/record.cc @@ -1,17 +1,19 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Licensed under the MIT License. -#include "orttraining/training_ops/cpu/controlflow/record.h" +#include "record.h" #include "core/providers/cpu/tensor/utils.h" -#include "common.h" namespace onnxruntime { namespace contrib { -void record_event_in_tensor(const Tensor& event_id_tensor) { - const int64_t event_id = *event_id_tensor.template Data(); - ORT_ENFORCE(event_id != -1, "-1 is reserved for skip wait, so cannot be used in RecordEvent"); - OrtEventPool::GetInstance().SignalEvent(event_id); +template +std::vector> AliasRange(int start, int end) { + std::vector> aliases; + for (int i = start; i < end; i++) { + aliases.push_back(std::pair(input_start + i, output_start + i)); + } + return aliases; } ONNX_OPERATOR_KERNEL_EX( @@ -26,8 +28,12 @@ ONNX_OPERATOR_KERNEL_EX( RecordEvent); Status RecordEvent::Compute(OpKernelContext* ctx) const { - // Pass event-id tensor to event-recording helper function. - record_event_in_tensor(*ctx->Input(0)); + const Tensor* event_id_tensor = ctx->Input(0); + const int64_t event_id = *event_id_tensor->template Data(); + + ORT_RETURN_IF_NOT(event_id != -1, "-1 is reserved for skip wait, so cannot be used in RecordEvent"); + + OrtEventPool::GetInstance().SignalEvent(event_id); for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) { const Tensor* X = ctx->Input(i_out + 1); diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/record.h b/orttraining/orttraining/training_ops/cpu/controlflow/record.h index 61fb28abba..d4f02a612d 100644 --- a/orttraining/orttraining/training_ops/cpu/controlflow/record.h +++ b/orttraining/orttraining/training_ops/cpu/controlflow/record.h @@ -9,9 +9,6 @@ namespace onnxruntime { namespace contrib { -// Record the event ID stored in the input tensor. 
-void record_event_in_tensor(const Tensor& event_id_tensor); - class RecordEvent final : public OpKernel { public: RecordEvent(const OpKernelInfo& info) : OpKernel(info) { } diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/wait.cc b/orttraining/orttraining/training_ops/cpu/controlflow/wait.cc index 67f3128234..b08ba65ddd 100644 --- a/orttraining/orttraining/training_ops/cpu/controlflow/wait.cc +++ b/orttraining/orttraining/training_ops/cpu/controlflow/wait.cc @@ -1,23 +1,19 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "orttraining/training_ops/cpu/controlflow/wait.h" +#include "wait.h" #include "core/providers/cpu/tensor/utils.h" -#include "common.h" namespace onnxruntime { namespace contrib { -void wait_event_in_tensor(const Tensor& event_id_tensor) { - const int64_t event_id = *event_id_tensor.template Data(); - // -1 is reserved to skip wait event - if (event_id != -1) { - // Wait the event to be recorded by a RecordEvent operator. - OrtEventPool::GetInstance().WaitEvent(event_id); - // BUGBUG: seems this would cause hang when a event is being waited more than once - // Destory the recorded event. - OrtEventPool::GetInstance().ResetEvent(event_id); +template +std::vector> AliasRange(int start, int end) { + std::vector> aliases; + for (int i = start; i < end; i++) { + aliases.push_back(std::pair(input_start + i, output_start + i)); } + return aliases; } ONNX_OPERATOR_KERNEL_EX( @@ -32,7 +28,18 @@ ONNX_OPERATOR_KERNEL_EX( WaitEvent); Status WaitEvent::Compute(OpKernelContext* ctx) const { - wait_event_in_tensor(*ctx->Input(0)); + const Tensor* event_id_tensor = ctx->Input(0); + const int64_t event_id = *event_id_tensor->template Data(); + + // -1 is reserved to skip wait event + if (event_id != -1) { + // Wait the event to be recorded by a RecordEvent operator. 
+ OrtEventPool::GetInstance().WaitEvent(event_id); + + // BUGBUG: seems this would cause hang when a event is being waited more than once + // Destory the recorded event. + OrtEventPool::GetInstance().ResetEvent(event_id); + } for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) { const Tensor* X = ctx->Input(i_out + 1); diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/wait.h b/orttraining/orttraining/training_ops/cpu/controlflow/wait.h index 682aa4388e..dff514880f 100644 --- a/orttraining/orttraining/training_ops/cpu/controlflow/wait.h +++ b/orttraining/orttraining/training_ops/cpu/controlflow/wait.h @@ -2,6 +2,8 @@ // Licensed under the MIT License. #pragma once +#include +#include #include "core/common/common.h" #include "core/framework/op_kernel.h" #include "event_pool.h" @@ -9,9 +11,6 @@ namespace onnxruntime { namespace contrib { -// Wait for the event ID stored in the input tensor. -void wait_event_in_tensor(const Tensor& event_id_tensor); - class WaitEvent final : public OpKernel { public: WaitEvent(const OpKernelInfo& info) : OpKernel(info) { } diff --git a/orttraining/orttraining/training_ops/cpu/tensor/gather_nd.cc b/orttraining/orttraining/training_ops/cpu/tensor/gather_nd.cc new file mode 100644 index 0000000000..6e11910ea2 --- /dev/null +++ b/orttraining/orttraining/training_ops/cpu/tensor/gather_nd.cc @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "orttraining/training_ops/cpu/tensor/gather_nd.h" + +namespace onnxruntime { +namespace contrib { + +ONNX_OPERATOR_KERNEL_EX( + GatherND, + kOnnxDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) + .TypeConstraint("Tind", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), + GatherND); + +template +Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const { + auto input_tensor = context->Input(0); + auto indice_tensor = context->Input(1); + ORT_ENFORCE(input_tensor != nullptr); + ORT_ENFORCE(indice_tensor != nullptr); + + auto input_shape = input_tensor->Shape(); + auto indice_shape = indice_tensor->Shape(); + if (indice_shape.NumDimensions() == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "indices tensor must has rank larger than 0"); + } + + auto last_indice_dimension = indice_shape[indice_shape.NumDimensions() - 1] + axis_; + if (last_indice_dimension > static_cast(input_shape.NumDimensions())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "last dimension of indices must not be larger than rank of input tensor"); + } + + std::vector shape(indice_shape.GetDims().begin(), + indice_shape.GetDims().end() - 1); + shape.insert(shape.end(), + input_shape.GetDims().begin() + last_indice_dimension, + input_shape.GetDims().end()); + auto output_tensor = context->Output(0, TensorShape(shape)); + std::vector element_counts(last_indice_dimension + axis_, 0LL); // Number of elements for each input dimension + +#ifdef USE_OPENMP +#pragma omp parallel for +#endif + for (int64_t i = 0; i < last_indice_dimension; ++i) { + element_counts[i] = input_shape.SizeFromDimension(i + 1); + } + + auto last_dim_size = indice_shape.SizeFromDimension(indice_shape.NumDimensions() - 1); +#ifdef USE_OPENMP +#pragma omp parallel for +#endif + for (int64_t i = axis_ - 1; i >= 0; --i) { + element_counts[last_indice_dimension + i] = 
indice_shape.SizeFromDimension(i + 1) / last_dim_size; + } + + int64_t err_indice = 0; + p.element_bytes = input_tensor->DataType()->Size(); + p.element_to_copy = input_shape.SizeFromDimension(last_indice_dimension); + p.bytes_to_copy = p.element_bytes * p.element_to_copy; + auto indice_offset = indice_tensor->Data(); + auto offset_count = indice_shape.Size() / (last_indice_dimension - axis_); // Times to copy + p.element_offsets.assign(offset_count, 0LL); + + if (input_tensor->DataType() == DataTypeImpl::GetType()) { + p.input_str_base = static_cast(input_tensor->DataRaw()); + p.output_str_base = static_cast(output_tensor->MutableDataRaw()); + } else { + p.input_base = static_cast(input_tensor->DataRaw()); + p.output_base = static_cast(output_tensor->MutableDataRaw()); + } + + //Compute the element_offset +#ifdef USE_OPENMP +#pragma omp parallel for +#endif + for (int64_t i = 0; i < offset_count; ++i) { + int64_t reminder = i; + for (int64_t j = 0; j < axis_; ++j) { + int64_t idx = reminder / element_counts[last_indice_dimension + j]; + p.element_offsets[i] += idx * element_counts[j]; + reminder -= (idx * element_counts[last_indice_dimension + j]); + } + for (int64_t j = axis_; j < last_indice_dimension; ++j) { + auto indice = *(indice_offset + i * (last_indice_dimension - axis_) + (j - axis_)); + if (indice < 0 || indice >= input_shape[j]) { + err_indice = indice; + } + p.element_offsets[i] += indice * element_counts[j]; + } + } + + return err_indice == 0 ? Status::OK() : ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", err_indice); +} + +template Status GatherNDBase::PrepareForCompute(OpKernelContext*, Prepare&) const; +template Status GatherNDBase::PrepareForCompute(OpKernelContext*, Prepare&) const; + +Status GatherND::Compute(OpKernelContext* context) const { + Prepare p; + ORT_RETURN_IF_ERROR(context->Input(1)->DataType() == DataTypeImpl::GetType() ? 
PrepareForCompute(context, p) : PrepareForCompute(context, p)); + + return nullptr == p.input_str_base ? GatherNumber(p) : GatherString(p); +} + +Status GatherND::GatherNumber(const Prepare& p) const { +#ifdef USE_OPENMP +#pragma omp parallel for +#endif + for (int64_t i = 0; i < static_cast(p.element_offsets.size()); ++i) { + memcpy(p.output_base + i * p.bytes_to_copy, + p.input_base + p.element_offsets[i] * p.element_bytes, + p.bytes_to_copy); + } + + return Status::OK(); +} + +Status GatherND::GatherString(const Prepare& p) const { +#ifdef USE_OPENMP +#pragma omp parallel for +#endif + for (int64_t i = 0; i < static_cast(p.element_offsets.size()); ++i) { + for (int64_t j = 0; j < static_cast(p.element_to_copy); ++j) { + p.output_str_base[i * p.element_to_copy + j] = p.input_str_base[p.element_offsets[i] + j]; + } + } + + return Status::OK(); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/tensor/gather_nd.h b/orttraining/orttraining/training_ops/cpu/tensor/gather_nd.h new file mode 100644 index 0000000000..2b1caddad9 --- /dev/null +++ b/orttraining/orttraining/training_ops/cpu/tensor/gather_nd.h @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/platform/threadpool.h" + +namespace onnxruntime { +namespace contrib { + +class GatherNDBase { + protected: + struct Prepare { + const uint8_t* input_base; + const std::string* input_str_base; + uint8_t* output_base; + std::string* output_str_base; + uint64_t bytes_to_copy; + uint64_t element_bytes; + uint64_t element_to_copy; + std::vector element_offsets; + + Prepare() : input_base(nullptr), + input_str_base(nullptr), + output_base(nullptr), + output_str_base(nullptr), + bytes_to_copy(0), + element_bytes(0), + element_to_copy(0), + element_offsets(0) {} + }; // struct Prepare + + template + Status PrepareForCompute(OpKernelContext* context, Prepare& p) const; + int64_t axis_; +}; // class GatherNDBase + +class GatherND final : public OpKernel, protected GatherNDBase { + public: + explicit GatherND(const OpKernelInfo& info) : OpKernel(info) { + info.GetAttrOrDefault("axis", &axis_, static_cast(0)); + } + Status Compute(OpKernelContext* context) const override; + + private: + Status GatherNumber(const Prepare& p) const; + Status GatherString(const Prepare& p) const; +}; + +} // namespace contrib +} // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc b/orttraining/orttraining/training_ops/cpu_training_kernels.cc similarity index 98% rename from orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc rename to orttraining/orttraining/training_ops/cpu_training_kernels.cc index c8783baff9..edde3947dc 100644 --- a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cpu_training_kernels.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "orttraining/training_ops/cpu/cpu_training_kernels.h" +#include "orttraining/training_ops/cpu_training_kernels.h" #include "core/graph/constants.h" namespace onnxruntime { @@ -17,6 +17,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InPla class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ZeroGradient); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Group); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxCrossEntropy); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxCrossEntropyGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SparseSoftmaxCrossEntropy); @@ -105,6 +106,7 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.h b/orttraining/orttraining/training_ops/cpu_training_kernels.h similarity index 100% rename from orttraining/orttraining/training_ops/cpu/cpu_training_kernels.h rename to orttraining/orttraining/training_ops/cpu_training_kernels.h diff --git a/orttraining/orttraining/training_ops/cuda/controlflow/record.cc b/orttraining/orttraining/training_ops/cuda/controlflow/record.cc deleted file mode 100644 index 2f8bb82a2e..0000000000 --- a/orttraining/orttraining/training_ops/cuda/controlflow/record.cc +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "orttraining/training_ops/cuda/controlflow/record.h" -#include "core/providers/cpu/tensor/utils.h" -// Include RecordEvent's utility functions shared by CPU and GPU implementations. 
-#include "orttraining/training_ops/cpu/controlflow/common.h" -// Include event mechanism shared by CPU and GPU implementations. -#include "orttraining/training_ops/cpu/controlflow/event_pool.h" -#include "orttraining/training_ops/cpu/controlflow/record.h" - -namespace onnxruntime { -namespace cuda { - -ONNX_OPERATOR_KERNEL_EX( - RecordEvent, - kMSDomain, - 1, - kCudaExecutionProvider, - KernelDefBuilder() - .InputMemoryType(0) /* Keep EventIdentifier in CPU */ - .TypeConstraint("TInt64", DataTypeImpl::GetTensorType()) - .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) - .Alias(onnxruntime::contrib::AliasRange<1, 0>(0, 1024)), - RecordEvent); - -Status RecordEvent::ComputeInternal(OpKernelContext* ctx) const { - // Reuse CPU helper to record event because event tensor is a CPU tensor. - onnxruntime::contrib::record_event_in_tensor(*ctx->Input(0)); - - for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) { - // This iteration copies (i-1)-th input to i-th output. - const Tensor* X = ctx->Input(i_out + 1); - const TensorShape& data_shape = X->Shape(); - Tensor* Y = ctx->Output(i_out, data_shape); - CopyTensor(*X, *Y); - } - - return Status::OK(); -} - -} // namespace cuda -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/controlflow/record.h b/orttraining/orttraining/training_ops/cuda/controlflow/record.h deleted file mode 100644 index 0063af48f1..0000000000 --- a/orttraining/orttraining/training_ops/cuda/controlflow/record.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once -#include "core/common/common.h" -#include "core/providers/cuda/cuda_common.h" -#include "core/providers/cuda/cudnn_common.h" - -namespace onnxruntime { -namespace cuda { - -class RecordEvent final : public CudaKernel { -public: - RecordEvent(const OpKernelInfo& info) : CudaKernel(info) { } - Status ComputeInternal(OpKernelContext* context) const override; -}; - -} // namespace cuda -} // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/training_ops/cuda/controlflow/wait.cc b/orttraining/orttraining/training_ops/cuda/controlflow/wait.cc deleted file mode 100644 index f90177c51a..0000000000 --- a/orttraining/orttraining/training_ops/cuda/controlflow/wait.cc +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "orttraining/training_ops/cuda/controlflow/wait.h" -#include "core/providers/cpu/tensor/utils.h" -// Include RecordEvent's utility functions shared by CPU and GPU implementations. -#include "orttraining/training_ops/cpu/controlflow/common.h" -// Include event mechanism shared by CPU and GPU implementations. -#include "orttraining/training_ops/cpu/controlflow/event_pool.h" -#include "orttraining/training_ops/cpu/controlflow/wait.h" - -namespace onnxruntime { -namespace cuda { - -ONNX_OPERATOR_KERNEL_EX( - WaitEvent, - kMSDomain, - 1, - kCudaExecutionProvider, - KernelDefBuilder() - .InputMemoryType(0) /* CPU variable */ - .TypeConstraint("TInt64", DataTypeImpl::GetTensorType()) - .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) - .Alias(onnxruntime::contrib::AliasRange<1, 0>(0, 1024)), - WaitEvent); - -Status WaitEvent::ComputeInternal(OpKernelContext* ctx) const { - // Reuse CPU helper to wait event because event tensor is a CPU tensor. 
- onnxruntime::contrib::wait_event_in_tensor(*ctx->Input(0)); - - for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) { - // This iteration copies (i-1)-th input to i-th output. - const Tensor* X = ctx->Input(i_out + 1); - const TensorShape& data_shape = X->Shape(); - Tensor* Y = ctx->Output(i_out, data_shape); - CopyTensor(*X, *Y); - } - - return Status::OK(); -} - -} // namespace cuda -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/controlflow/wait.h b/orttraining/orttraining/training_ops/cuda/controlflow/wait.h deleted file mode 100644 index b4a687fef6..0000000000 --- a/orttraining/orttraining/training_ops/cuda/controlflow/wait.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/common.h" -#include "core/providers/cuda/cuda_common.h" -#include "core/providers/cuda/cudnn_common.h" - -namespace onnxruntime { -namespace cuda { - -class WaitEvent final : public CudaKernel { -public: - WaitEvent(const OpKernelInfo& info) : CudaKernel(info) { } - Status ComputeInternal(OpKernelContext* context) const override; -}; - -} // namespace cuda -} // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/training_ops/cuda/nn/dropout.cc b/orttraining/orttraining/training_ops/cuda/nn/dropout.cc index 44e39f8af3..bced973f88 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/dropout.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/dropout.cc @@ -37,6 +37,7 @@ REGISTER_KERNEL_TYPED(Dropout, kOnnxDomain, 12, double, double, 1) template Status Dropout::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; + typedef typename ToCudaType::MappedType CudaT2; //Get X_data const Tensor* X = context->Input(0); @@ -67,7 +68,7 @@ Status Dropout::ComputeInternal(OpKernelContext* context) const { "T2 must be float16 or float or double"); if (ratio) { - 
ratio_data = static_cast(*(ratio->template Data())); + ratio_data = static_cast(*reinterpret_cast(ratio->template Data())); } else { ratio_data = default_ratio_; } @@ -111,7 +112,7 @@ Status DropoutGrad::ComputeInternal(OpKernelContext* context) const { "T2 must be float16 or float or double"); if (ratio) { - ratio_data = static_cast(*(ratio->template Data())); + ratio_data = static_cast(*reinterpret_cast(ratio->template Data())); } else { ratio_data = default_ratio_; } diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_grad.cc b/orttraining/orttraining/training_ops/cuda/tensor/gather_grad.cc index 56f0e4e984..70c9c58247 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/gather_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/tensor/gather_grad.cc @@ -3,6 +3,7 @@ #include "orttraining/training_ops/cuda/tensor/gather_grad.h" #include "orttraining/training_ops/cuda/tensor/gather_grad_impl.h" +#include "orttraining/training_ops/cuda/tensor/thrustallocator.h" #include "core/providers/common.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/cuda/tensor/gather_nd.cc b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd.cc similarity index 62% rename from onnxruntime/core/providers/cuda/tensor/gather_nd.cc rename to orttraining/orttraining/training_ops/cuda/tensor/gather_nd.cc index 53e2fcd153..62f23d39ff 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_nd.cc +++ b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd.cc @@ -1,13 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/providers/cuda/tensor/gather_nd.h" -#include "core/providers/cuda/tensor/gather_nd_impl.h" +#include "orttraining/training_ops/cuda/tensor/gather_nd.h" +#include "orttraining/training_ops/cuda/tensor/gather_nd_impl.h" #include "core/providers/cuda/shared_inc/cuda_utils.h" namespace onnxruntime { namespace cuda { +namespace { Status CheckBatchDimensionsMatch( size_t num_batch_dimensions, const std::vector>& tensor_shapes) { @@ -37,6 +38,7 @@ Status CheckBatchDimensionsMatch( return Status::OK(); } +} // namespace #define TYPED_FUNCTION_CALL_FWD(T) \ if (T_type == DataTypeImpl::GetType()) { \ @@ -48,10 +50,9 @@ Status CheckBatchDimensionsMatch( GatherNDGradImpl::MappedType>(num_slices, kernel_input_data, kernel_output_data, slice_size, input_slice_offsets_buffer.get()); \ } - template Status GatherNDBase::CommonComputeKernel( - const int64_t batch_dims, + const int64_t axis, const TensorShape& input_shape, const Tensor* kernel_input_tensor, Tensor* kernel_output_tensor, @@ -64,9 +65,9 @@ Status GatherNDBase::CommonComputeKernel( const auto num_slice_dims = indices_shape[indices_shape.NumDimensions() - 1]; const auto num_slices = indices_shape.SizeToDimension(indices_shape.NumDimensions() - 1); - const auto slice_size = input_shape.SizeFromDimension(batch_dims + num_slice_dims); - const auto num_batches = input_shape.SizeToDimension(batch_dims); - const auto input_batch_stride = input_shape.SizeFromDimension(batch_dims); + const auto slice_size = input_shape.SizeFromDimension(axis + num_slice_dims); + const auto num_batches = input_shape.SizeToDimension(axis); + const auto input_batch_stride = input_shape.SizeFromDimension(axis); const auto num_slices_per_batch = num_slices / num_batches; const TIndex* const indices_data = indices_tensor->Data(); @@ -78,7 +79,7 @@ Status GatherNDBase::CommonComputeKernel( auto running_product = slice_size; for (int64_t i = 0; i < num_slice_dims; ++i) { sizes_from_slice_dims[num_slice_dims - 1 - i] = running_product; - 
running_product *= input_shape[batch_dims + num_slice_dims - 1 - i]; + running_product *= input_shape[axis + num_slice_dims - 1 - i]; } } @@ -91,11 +92,9 @@ Status GatherNDBase::CommonComputeKernel( auto input_slice_offsets_buffer = GetScratchBuffer(num_slices); - TArray input_dims(input_shape.GetDims()); + // TODO error handling for invalid slice indices // TODO reuse computed slice offsets from GatherND in GatherNDGrad ComputeSliceOffsetsImpl( - batch_dims, - input_dims, num_slices, num_slices_per_batch, input_batch_stride, @@ -110,15 +109,10 @@ Status GatherNDBase::CommonComputeKernel( TYPED_FUNCTION_CALL_FWD(MLFloat16); TYPED_FUNCTION_CALL_FWD(double); } else { -#ifdef ENABLE_TRAINING MLDataType T_type = kernel_input_tensor->DataType(); TYPED_FUNCTION_CALL_BWD(float); TYPED_FUNCTION_CALL_BWD(MLFloat16); TYPED_FUNCTION_CALL_BWD(double); -#else - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Gradient computation is only supported in the training mode."); -#endif } return Status::OK(); @@ -127,11 +121,11 @@ Status GatherNDBase::CommonComputeKernel( #undef TYPED_FUNCTION_CALL_FWD #undef TYPED_FUNCTION_CALL_BWD -#define REGISTER_KERNEL_TYPED_GATHER_ND(TIndex, ver) \ +#define REGISTER_KERNEL_TYPED_GATHER_ND(TIndex) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ GatherND, \ kOnnxDomain, \ - ver, \ + 1, \ TIndex, \ kCudaExecutionProvider, \ KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ @@ -139,11 +133,8 @@ Status GatherNDBase::CommonComputeKernel( .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), \ GatherND); -// TODO: decprecate GatherND-1 after updating training models to opset-12 -#ifdef ENABLE_TRAINING -REGISTER_KERNEL_TYPED_GATHER_ND(int64_t, 1) -#endif -REGISTER_KERNEL_TYPED_GATHER_ND(int64_t, 12) +REGISTER_KERNEL_TYPED_GATHER_ND(int64_t) +REGISTER_KERNEL_TYPED_GATHER_ND(int32_t) template Status GatherND::ComputeInternal(OpKernelContext* context) const { @@ -160,14 +151,14 @@ Status GatherND::ComputeInternal(OpKernelContext* 
context) const { "indices tensor must has rank larger than 0"); } - auto last_indices_dimension = batch_dims_ + indices_shape[indices_shape.NumDimensions() - 1]; + auto last_indices_dimension = axis_ + indices_shape[indices_shape.NumDimensions() - 1]; if (last_indices_dimension > static_cast(input_shape.NumDimensions())) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "last dimension of indices must not be larger than rank of input tensor"); } ORT_RETURN_IF_ERROR(CheckBatchDimensionsMatch( - static_cast(batch_dims_), {input_shape, indices_shape})); + static_cast(axis_), {input_shape, indices_shape})); //Output shape std::vector shape(indices_shape.GetDims().begin(), indices_shape.GetDims().end() - 1); @@ -176,10 +167,67 @@ Status GatherND::ComputeInternal(OpKernelContext* context) const { auto output_tensor = context->Output(0, TensorShape(shape)); //Compute - auto status = CommonComputeKernel(batch_dims_, input_shape, input_tensor, output_tensor, indices_shape, indices_tensor, true); + auto status = CommonComputeKernel(axis_, input_shape, input_tensor, output_tensor, indices_shape, indices_tensor, true); return status; } +#define REGISTER_KERNEL_TYPED_GATHER_ND_GRAD(TIndex) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GatherNDGrad, \ + kOnnxDomain, \ + 1, \ + TIndex, \ + kCudaExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) \ + .TypeConstraint("Tind", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(0), \ + GatherNDGrad); + +REGISTER_KERNEL_TYPED_GATHER_ND_GRAD(int64_t) +REGISTER_KERNEL_TYPED_GATHER_ND_GRAD(int32_t) + +template +Status GatherNDGrad::ComputeInternal(OpKernelContext* context) const { + auto shape_tensor = context->Input(0); + auto indices_tensor = context->Input(1); + auto update_tensor = context->Input(2); + ORT_RETURN_IF_NOT(shape_tensor != nullptr); + 
ORT_RETURN_IF_NOT(indices_tensor != nullptr); + ORT_RETURN_IF_NOT(update_tensor != nullptr); + + auto indices_shape = indices_tensor->Shape(); + auto update_shape = update_tensor->Shape(); + + if (indices_shape.NumDimensions() == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "indices tensor must has rank larger than 0"); + } + + auto last_indices_dimension = axis_ + indices_shape[indices_shape.NumDimensions() - 1]; + + //Output + auto shape_data = shape_tensor->Data(); + auto input_shape = TensorShape(shape_data, shape_tensor->SizeInBytes() / sizeof(shape_tensor->DataType())); + + if (last_indices_dimension > static_cast(input_shape.NumDimensions())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "last dimension of indices must not be larger than rank of input tensor"); + } + + ORT_RETURN_IF_ERROR(CheckBatchDimensionsMatch( + static_cast(axis_), {input_shape, indices_shape, update_shape})); + + auto output_tensor = context->Output(0, input_shape); + + // TODO this memset can be expensive, a sparse tensor representation would help here + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(output_tensor->MutableDataRaw(), 0, output_tensor->SizeInBytes())); + + auto status = CommonComputeKernel(axis_, input_shape, update_tensor, output_tensor, indices_shape, indices_tensor, false); + return status; +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/gather_nd.h b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd.h similarity index 72% rename from onnxruntime/core/providers/cuda/tensor/gather_nd.h rename to orttraining/orttraining/training_ops/cuda/tensor/gather_nd.h index 9e5f5116a5..70c429d6f9 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_nd.h +++ b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd.h @@ -10,21 +10,17 @@ namespace onnxruntime { namespace cuda { -Status CheckBatchDimensionsMatch( - size_t num_batch_dimensions, - const std::vector>& tensor_shapes); - class 
GatherNDBase : public CudaKernel { public: GatherNDBase(const OpKernelInfo& info) : CudaKernel(info) { - info.GetAttrOrDefault("batch_dims", &batch_dims_, static_cast(0)); - ORT_ENFORCE(batch_dims_ >= 0); + info.GetAttrOrDefault("axis", &axis_, static_cast(0)); + ORT_ENFORCE(axis_ >= 0); } protected: template Status CommonComputeKernel( - const int64_t batch_dims, + const int64_t axis, const TensorShape& input_shape, const Tensor* input_tensor, Tensor* output_tensor, @@ -32,7 +28,7 @@ class GatherNDBase : public CudaKernel { const Tensor* indices_tensor, const bool fwd) const; - int64_t batch_dims_; + int64_t axis_; }; template @@ -42,5 +38,12 @@ class GatherND final : public GatherNDBase { Status ComputeInternal(OpKernelContext* context) const override; }; +template +class GatherNDGrad final : public GatherNDBase { + public: + GatherNDGrad(const OpKernelInfo& info) : GatherNDBase(info) {} + Status ComputeInternal(OpKernelContext* context) const override; +}; + } // namespace cuda } // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_gard_impl.cu b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_gard_impl.cu deleted file mode 100644 index 8c37836705..0000000000 --- a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_gard_impl.cu +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/providers/cuda/tensor/gather_nd_impl.h" - -#include "core/providers/cuda/cu_inc/common.cuh" -#include "core/providers/cuda/atomic/common.cuh" - -namespace onnxruntime { -namespace cuda { - -template -__global__ void _GatherNDGradKernel( - const size_t num_slices, - const T* update_data, - T* output_data, - const size_t slice_size, - const int64_t* slice_offsets) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(i, num_slices * slice_size); - uint64_t slice_offset = slice_offsets[i / slice_size]; - size_t j = i % slice_size; - atomic_add(output_data + slice_offset + j, update_data[i]); -}; - -template -void GatherNDGradImpl( - const size_t num_slices, - const void* update_data, - void* output_data, - const size_t slice_size, - const int64_t* input_slice_offsets_data) { - const auto blocks_per_grid = CeilDiv(num_slices * slice_size, GridDim::maxThreadsPerBlock); - _GatherNDGradKernel<<>>( - num_slices, static_cast(update_data), static_cast(output_data), slice_size, input_slice_offsets_data); -} - -#define SPECIALIZED_GRAD_IMPL(T) \ - template void GatherNDGradImpl(const size_t num_slices, const void* update_data, void* output_data, const size_t slice_size, const int64_t* input_slice_offsets_data) - -SPECIALIZED_GRAD_IMPL(float); -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 -SPECIALIZED_GRAD_IMPL(half); -SPECIALIZED_GRAD_IMPL(double); -#endif - -} // namespace cuda -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.cc b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.cc deleted file mode 100644 index d7878cf0a6..0000000000 --- a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.cc +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "orttraining/training_ops/cuda/tensor/gather_nd_grad.h" -#include "core/providers/cuda/shared_inc/cuda_utils.h" - -namespace onnxruntime { -namespace cuda { - -#define REGISTER_KERNEL_TYPED_GATHER_ND_GRAD(TIndex) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - GatherNDGrad, \ - kMSDomain, \ - 1, \ - TIndex, \ - kCudaExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ - DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) \ - .TypeConstraint("Tind", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(0), \ - GatherNDGrad); - -REGISTER_KERNEL_TYPED_GATHER_ND_GRAD(int64_t) - -template -Status GatherNDGrad::ComputeInternal(OpKernelContext* context) const { - auto shape_tensor = context->Input(0); - auto indices_tensor = context->Input(1); - auto update_tensor = context->Input(2); - ORT_RETURN_IF_NOT(shape_tensor != nullptr); - ORT_RETURN_IF_NOT(indices_tensor != nullptr); - ORT_RETURN_IF_NOT(update_tensor != nullptr); - - auto indices_shape = indices_tensor->Shape(); - auto update_shape = update_tensor->Shape(); - - if (indices_shape.NumDimensions() == 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "indices tensor must has rank larger than 0"); - } - - auto last_indices_dimension = batch_dims_ + indices_shape[indices_shape.NumDimensions() - 1]; - - //Output - auto shape_data = shape_tensor->Data(); - auto input_shape = TensorShape(shape_data, shape_tensor->SizeInBytes() / sizeof(shape_tensor->DataType())); - - if (last_indices_dimension > static_cast(input_shape.NumDimensions())) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "last dimension of indices must not be larger than rank of input tensor"); - } - - ORT_RETURN_IF_ERROR(CheckBatchDimensionsMatch( - static_cast(batch_dims_), {input_shape, indices_shape, update_shape})); - - auto output_tensor = context->Output(0, input_shape); - - // TODO this memset can be expensive, a 
sparse tensor representation would help here - CUDA_RETURN_IF_ERROR(cudaMemsetAsync(output_tensor->MutableDataRaw(), 0, output_tensor->SizeInBytes())); - - auto status = CommonComputeKernel(batch_dims_, input_shape, update_tensor, output_tensor, indices_shape, indices_tensor, false); - return status; -} - -} // namespace cuda -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.h b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.h deleted file mode 100644 index d29d85befa..0000000000 --- a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.h +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include "core/framework/op_kernel.h" -#include "core/providers/cuda/cuda_common.h" -#include "core/providers/cuda/tensor/gather_nd.h" - -namespace onnxruntime { -namespace cuda { - -template -class GatherNDGrad final : public GatherNDBase { - public: - GatherNDGrad(const OpKernelInfo& info) : GatherNDBase(info) {} - Status ComputeInternal(OpKernelContext* context) const override; -}; - -} // namespace cuda -} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/cuda/tensor/gather_nd_impl.cu b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_impl.cu similarity index 68% rename from onnxruntime/core/providers/cuda/tensor/gather_nd_impl.cu rename to orttraining/orttraining/training_ops/cuda/tensor/gather_nd_impl.cu index 91b0169e68..d754dcfe9f 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_nd_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_impl.cu @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/providers/cuda/tensor/gather_nd_impl.h" +#include "orttraining/training_ops/cuda/tensor/gather_nd_impl.h" #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/atomic/common.cuh" @@ -11,8 +11,6 @@ namespace cuda { template __global__ void _ComputeSliceOffsetsKernel( - const int64_t batch_dims, - const TArray input_dims, const size_t num_slices, const size_t num_slices_per_batch, const size_t input_batch_stride, @@ -28,12 +26,7 @@ __global__ void _ComputeSliceOffsetsKernel( const TIndex* const slice_indices = indices_data + slice_idx * num_slice_dims; size_t relative_slice_offset = 0; for (size_t dim_idx = 0; dim_idx < num_slice_dims; ++dim_idx) { - int64_t index = static_cast(slice_indices[dim_idx]); - const size_t input_dim_idx = batch_dims + dim_idx; - CUDA_KERNEL_ASSERT(index >= -input_dims[input_dim_idx] && index < input_dims[input_dim_idx]); - if (index < 0) index += input_dims[input_dim_idx]; - - relative_slice_offset += index * sizes_from_slice_dims_data[dim_idx]; + relative_slice_offset += static_cast(slice_indices[dim_idx]) * sizes_from_slice_dims_data[dim_idx]; } input_slice_offsets_data[slice_idx] = base_offset + relative_slice_offset; @@ -51,10 +44,21 @@ __global__ void _GatherNDKernel( output_data[i] = input_data[slice_offset + i % slice_size]; }; +template +__global__ void _GatherNDGradKernel( + const size_t num_slices, + const T* update_data, + T* output_data, + const size_t slice_size, + const int64_t* slice_offsets) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(i, num_slices * slice_size); + uint64_t slice_offset = slice_offsets[i / slice_size]; + size_t j = i % slice_size; + atomic_add(output_data + slice_offset + j, update_data[i]); +}; + template void ComputeSliceOffsetsImpl( - const int64_t batch_dims, - const TArray input_dims, const size_t num_slices, const size_t num_slices_per_batch, const size_t input_batch_stride, @@ -64,8 +68,6 @@ void ComputeSliceOffsetsImpl( int64_t* const input_slice_offsets_data) { 
// num_slices elements const auto blocks_per_grid = CeilDiv(num_slices, GridDim::maxThreadsPerBlock); _ComputeSliceOffsetsKernel<<>>( - batch_dims, - input_dims, num_slices, num_slices_per_batch, input_batch_stride, @@ -87,28 +89,45 @@ void GatherNDImpl( num_slices, static_cast(input_data), static_cast(output_data), slice_size, input_slice_offsets_data); } +template +void GatherNDGradImpl( + const size_t num_slices, + const void* update_data, + void* output_data, + const size_t slice_size, + const int64_t* input_slice_offsets_data) { + const auto blocks_per_grid = CeilDiv(num_slices * slice_size, GridDim::maxThreadsPerBlock); + _GatherNDGradKernel<<>>( + num_slices, static_cast(update_data), static_cast(output_data), slice_size, input_slice_offsets_data); +} + #define SPECIALIZED_COMPUTE_SLICE_OFFSETS_IMPL(TIndex) \ template void ComputeSliceOffsetsImpl( \ - const int64_t batch_dims, \ - const TArray input_dims, \ const size_t num_slices, \ const size_t num_slices_per_batch, \ const size_t input_batch_stride, \ const size_t num_slice_dims, \ const int64_t* const sizes_from_slice_dims_data, \ const TIndex* const indices_data, \ - int64_t* const input_slice_offsets_data); + int64_t* const input_slice_offsets_data) #define SPECIALIZED_IMPL(T) \ - template void GatherNDImpl(const size_t num_slices, const void* input_data, void* output_data, const size_t slice_size, const int64_t* input_slice_offsets_data); + template void GatherNDImpl(const size_t num_slices, const void* input_data, void* output_data, const size_t slice_size, const int64_t* input_slice_offsets_data) -SPECIALIZED_COMPUTE_SLICE_OFFSETS_IMPL(int32_t) -SPECIALIZED_COMPUTE_SLICE_OFFSETS_IMPL(int64_t) +#define SPECIALIZED_GRAD_IMPL(T) \ + template void GatherNDGradImpl(const size_t num_slices, const void* update_data, void* output_data, const size_t slice_size, const int64_t* input_slice_offsets_data) -SPECIALIZED_IMPL(float) +SPECIALIZED_COMPUTE_SLICE_OFFSETS_IMPL(int32_t); 
+SPECIALIZED_COMPUTE_SLICE_OFFSETS_IMPL(int64_t); + +SPECIALIZED_IMPL(float); +SPECIALIZED_GRAD_IMPL(float); #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 -SPECIALIZED_IMPL(half) -SPECIALIZED_IMPL(double) +SPECIALIZED_IMPL(half); +SPECIALIZED_GRAD_IMPL(half); + +SPECIALIZED_IMPL(double); +SPECIALIZED_GRAD_IMPL(double); #endif } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/tensor/gather_nd_impl.h b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_impl.h similarity index 91% rename from onnxruntime/core/providers/cuda/tensor/gather_nd_impl.h rename to orttraining/orttraining/training_ops/cuda/tensor/gather_nd_impl.h index e989fb330a..aa315d9fa3 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_nd_impl.h +++ b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_impl.h @@ -9,8 +9,6 @@ namespace cuda { template void ComputeSliceOffsetsImpl( - const int64_t batch_dims, - const TArray input_dims, const size_t num_slices, const size_t num_slices_per_batch, const size_t input_batch_stride, @@ -27,7 +25,6 @@ void GatherNDImpl( const size_t slice_size, const int64_t* input_slice_offsets_data); -#ifdef ENABLE_TRAINING template void GatherNDGradImpl( const size_t num_slices, @@ -35,7 +32,6 @@ void GatherNDGradImpl( void* output_data, const size_t slice_size, const int64_t* input_slice_offsets_data); -#endif } // namespace cuda } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/tensor/thrustallocator.h b/orttraining/orttraining/training_ops/cuda/tensor/thrustallocator.h new file mode 100644 index 0000000000..182b14b1a4 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/thrustallocator.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/framework/allocator.h" + +namespace onnxruntime { +namespace cuda { + +class ThrustAllocator { + public: + typedef char value_type; + + ThrustAllocator(IAllocator* alloc) : alloc_(alloc) {} + + char* allocate(std::ptrdiff_t size) { + return static_cast(alloc_->Alloc(size)); + } + + void deallocate(char* p, size_t /*size*/) { + alloc_->Free(p); + } + + private: + IAllocator* alloc_; +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda_training_kernels.cc similarity index 97% rename from orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc rename to orttraining/orttraining/training_ops/cuda_training_kernels.cc index 537864f1a1..76c04462ae 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda_training_kernels.cc @@ -88,10 +88,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_MLFloat16, DropoutGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_float, DropoutGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double, DropoutGrad); - -// TODO: decprecate GatherND-1 after updating training models to opset-12 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, GatherND); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, GatherNDGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, GatherND); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, GatherNDGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, 
GatherNDGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DivGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, DivGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DivGrad); @@ -131,9 +131,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Send class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Recv); #endif -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, RecordEvent); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WaitEvent); - #ifdef USE_NCCL class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NcclAllReduce); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NcclAllGather); @@ -212,10 +209,10 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - - // TODO: decprecate GatherND-1 after updating training models to opset-12 BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, @@ -266,9 +263,6 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, #endif - BuildKernelCreateInfo, - BuildKernelCreateInfo, - #ifdef USE_NCCL BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.h b/orttraining/orttraining/training_ops/cuda_training_kernels.h similarity index 100% rename from orttraining/orttraining/training_ops/cuda/cuda_training_kernels.h rename to orttraining/orttraining/training_ops/cuda_training_kernels.h From f487cc0b2835892dcd223b9eb9093f62e64ea294 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 28 Apr 2020 00:03:16 -0700 
Subject: [PATCH 8/8] Fix Reshape Fusion with graph inputs (#3729) Use NodeArg to check root input; Add a check on constant initializer --- onnxruntime/core/optimizer/reshape_fusion.cc | 17 ++++++----- onnxruntime/core/optimizer/utils.cc | 6 +++- onnxruntime/core/optimizer/utils.h | 2 +- .../test/optimizer/graph_transform_test.cc | 21 ++++++++++++++ .../transform/fusion/reshape_fusion_gen.py | 27 ++++++++++++++++++ .../reshape_fusion_with_graph_inputs.onnx | Bin 0 -> 446 bytes 6 files changed, 62 insertions(+), 11 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/reshape_fusion_with_graph_inputs.onnx diff --git a/onnxruntime/core/optimizer/reshape_fusion.cc b/onnxruntime/core/optimizer/reshape_fusion.cc index 5d39939d30..399dcf744c 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.cc +++ b/onnxruntime/core/optimizer/reshape_fusion.cc @@ -46,7 +46,7 @@ each of which is a constant initializer or a Shape->Gather->Unsqueeze chain with index corresponding to the index of the argument.) Before fusion: - [Sub-graph Root Node ] + [Sub-graph Root] | / \ | Shape Shape | | | @@ -61,13 +61,14 @@ Before fusion: Reshape After fusion: - [Sub-graph Root Node] (Constant Initializer) + [Sub-graph Root] (Constant Initializer) \ [0, a, 0, b] \ / Reshape */ bool ReshapeFusion::Fuse_Subgraph1(Node& reshape, Graph& graph, const logging::Logger& logger) { - const Node* p_root = graph_utils::GetInputNode(reshape, 0); + // The root could be either a graph input or a node so use node arg to compare. + const NodeArg& root_input = *(reshape.InputDefs()[0]); const Node* p_concat = graph_utils::GetInputNode(reshape, 1); if (nullptr == p_concat) { @@ -90,11 +91,8 @@ bool ReshapeFusion::Fuse_Subgraph1(Node& reshape, Graph& graph, const logging::L enum class NodeType { Unsqueeze, Gather, Shape }; std::set> candidates_for_removal; for (int i = 0; i < concat_input_count; ++i) { - // First check if the i-th argument is an initializer. 
- // We do not check whether the initializer is constant. - // Some model uses constant initializer and some does not. - // Here we assume that no one will override the initializer using graph input. - if (optimizer_utils::AppendTensorFromInitializer(graph, *(concat.InputDefs()[i]), shape_value)) { + // First check if the i-th argument is a constant initializer. + if (optimizer_utils::AppendTensorFromInitializer(graph, *(concat.InputDefs()[i]), shape_value, true)) { continue; } @@ -113,7 +111,8 @@ bool ReshapeFusion::Fuse_Subgraph1(Node& reshape, Graph& graph, const logging::L const Node& gather = edges[1]->GetNode(); const Node& shape = edges[2]->GetNode(); - if (graph_utils::GetInputNode(shape, 0) != p_root) { + const NodeArg& shape_input = *(shape.InputDefs()[0]); + if (shape_input.Name() != root_input.Name()) { return false; } diff --git a/onnxruntime/core/optimizer/utils.cc b/onnxruntime/core/optimizer/utils.cc index 4aff66352a..052f5ff67f 100644 --- a/onnxruntime/core/optimizer/utils.cc +++ b/onnxruntime/core/optimizer/utils.cc @@ -141,7 +141,11 @@ bool IsAttributeWithExpectedValues(const Node& node, const std::string& attr_nam return true; } -bool AppendTensorFromInitializer(const Graph& graph, const NodeArg& input_arg, std::vector& data) { +bool AppendTensorFromInitializer(const Graph& graph, const NodeArg& input_arg, std::vector& data, bool require_constant) { + if (require_constant && !graph_utils::IsConstantInitializer(graph, input_arg.Name(), true)) { + return false; + } + const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr; if (!graph.GetInitializedTensor(input_arg.Name(), tensor_proto)) { return false; diff --git a/onnxruntime/core/optimizer/utils.h b/onnxruntime/core/optimizer/utils.h index 7887bfd360..0cafe0d49c 100644 --- a/onnxruntime/core/optimizer/utils.h +++ b/onnxruntime/core/optimizer/utils.h @@ -42,7 +42,7 @@ bool IsAttributeWithExpectedValues(const Node& node, const std::string& attr_nam /** Get values of an integer tensor from 
initializer, and append them to a vector. @remarks only support int32 and int64 tensor. This function does not clear vector before appending. */ -bool AppendTensorFromInitializer(const Graph& graph, const NodeArg& input_arg, std::vector& data); +bool AppendTensorFromInitializer(const Graph& graph, const NodeArg& input_arg, std::vector& data, bool require_constant = true); /** Check Shape of node input or output. @remarks when expected dim value > 0, the dim is expected to known and match the dim value. diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 7ff3f75e3d..6aae2911c0 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1116,6 +1116,27 @@ TEST_F(GraphTransformationTests, ReshapeFusionInternalReuseTest) { } } + +TEST_F(GraphTransformationTests, ReshapeFusionGraphInputsTest) { + auto model_uri = MODEL_FOLDER "fusion/reshape_fusion_with_graph_inputs.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level1); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Shape"], 1); + ASSERT_EQ(op_to_count["Gather"], 1); + ASSERT_EQ(op_to_count["Unsqueeze"], 1); + ASSERT_EQ(op_to_count["Concat"], 1); + ASSERT_EQ(op_to_count["Reshape"], 1); +} + + TEST_F(GraphTransformationTests, ExpandElimination) { auto model_uri = MODEL_FOLDER "expand_elimination.onnx"; std::shared_ptr model; diff --git a/onnxruntime/test/testdata/transform/fusion/reshape_fusion_gen.py b/onnxruntime/test/testdata/transform/fusion/reshape_fusion_gen.py index 
3f01ffea8f..156d9a2147 100644 --- a/onnxruntime/test/testdata/transform/fusion/reshape_fusion_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/reshape_fusion_gen.py @@ -70,3 +70,30 @@ graph = helper.make_graph( save_model(graph, 'reshape_fusion_internal_node_is_graph_output.onnx') + + +graph = helper.make_graph( + [ # nodes + helper.make_node("Shape", ["query"], ["shape0_out"], "shape0"), + helper.make_node("Gather", ["shape0_out", "indices0"], ["gather0_out"], "gather0", axis=0), + helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), + helper.make_node("Concat", ["a", "unsqueeze0_out"], ["concat_out"], "concat", axis=0), + helper.make_node("Reshape", ["doc_word_mask", "concat_out"], ["Result"], "reshape"), + ], + "Reshape_Fusion", #name + [ # inputs + helper.make_tensor_value_info('query', TensorProto.FLOAT, [1, 50]), + helper.make_tensor_value_info('doc_word_mask', TensorProto.FLOAT, [1, 200, 50]), + ], + [ # outputs + helper.make_tensor_value_info('Result', TensorProto.FLOAT, [10, 20, 'unk']), + ], + [ # initializers + helper.make_tensor('a', TensorProto.INT64, [1], [-1]), + helper.make_tensor('indices0', TensorProto.INT64, [], [1]), + ] +) + +save_model(graph, 'reshape_fusion_with_graph_inputs.onnx') + + diff --git a/onnxruntime/test/testdata/transform/fusion/reshape_fusion_with_graph_inputs.onnx b/onnxruntime/test/testdata/transform/fusion/reshape_fusion_with_graph_inputs.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e3609a566defdc46008d3c9b724e6cf366208cd9 GIT binary patch literal 446 zcmZ8d%TB{U3}h3MEEAfREeO;L(uz|~6%v<<0|W`~ATGT`OR}O-A%SdGJidmX;2-!a z*yaJDmMq!hu{}0*XP-7WoM$b|HwLv3PjVU;x|+C6%$=JRhI90apjn<~iBwbO--~J? z%cjR`6YgHsXy{{8yk_cPdcJra0()@*2s_) zc_OMtbrW;CAy_QNpR1>e1_2t|%!0L1sv_X8SRaWHT zn39n+;_G4uWIsR6brVH6f6#*gPmYbw|nJBBc*2{!$zS$BtM1+hJt Tg=}cbFun_i**afF*2w<>5r%(D literal 0 HcmV?d00001