diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 766c6d96b7..5bf8ca9275 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1110,6 +1110,14 @@ class Graph { // Graph value_info. std::vector value_info_; + // Strings which have been used as node names. + // New node name should not conflict with this set. + std::unordered_set generated_node_names_; + + // Strings which have been used as node_arg names. + // New node_arg name should not conflict this this set. + std::unordered_set generated_node_arg_names_; + // All node args owned by <*this> graph. Key is node arg name. std::unordered_map> node_args_; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 6135bf07c1..a6b4cfb8b7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -3,7 +3,6 @@ #include "attention.h" #include "core/framework/tensorprotoutils.h" -#include "core/providers/cuda/cudnn_common.h" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" #include "attention_impl.h" diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.h b/onnxruntime/contrib_ops/cuda/bert/attention.h index 9ec55f1793..b32ca9ee35 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention.h @@ -5,7 +5,7 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" -#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cpu/bert/attention.h" namespace onnxruntime { diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc index 8df41e9aab..d873971e8c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc @@ -2,7 +2,6 @@ // Licensed under the MIT License. #include "core/providers/common.h" -#include "core/providers/cuda/cudnn_common.h" #include "core/framework/tensorprotoutils.h" #include "onnx/defs/tensor_proto_util.h" #include "contrib_ops/cpu/bert/embed_layer_norm_helper.h" diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc index 1d9fe64336..0c965cd468 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "core/providers/common.h" -#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/cuda_common.h" #include "core/framework/tensorprotoutils.h" #include "fast_gelu.h" #include "fast_gelu_impl.h" diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc index 1eeaeb773d..61a1d64ca1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "core/providers/common.h" -#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/cuda_common.h" #include "core/framework/tensorprotoutils.h" #include "onnx/defs/tensor_proto_util.h" #include "skip_layer_norm.h" diff --git a/onnxruntime/contrib_ops/cuda/layer_norm.cc b/onnxruntime/contrib_ops/cuda/layer_norm.cc index 5bb3b237ab..bd6d14eef0 100644 --- a/onnxruntime/contrib_ops/cuda/layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/layer_norm.cc @@ -5,7 +5,7 @@ #include "layer_norm_impl.h" #include "core/providers/common.h" -#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/cuda_common.h" namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cuda/layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/layer_norm_impl.cu index d6867949d0..747a2ff70e 100644 --- a/onnxruntime/contrib_ops/cuda/layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/layer_norm_impl.cu @@ -107,13 +107,14 @@ __device__ void cuWelfordMuSigma2( cuWelfordOnlineSum(curr, mu, sigma2, count); } // intra-warp reductions - for (int l = 0; l <= 4; ++l) { - int srcLaneB = (threadIdx.x + (1 << l)) & 31; - U muB = WARP_SHFL(mu, srcLaneB); - U countB = WARP_SHFL(count, srcLaneB); - U sigma2B = WARP_SHFL(sigma2, srcLaneB); + #pragma unroll + for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) { + U muB = WARP_SHFL_DOWN(mu, stride); + U countB = WARP_SHFL_DOWN(count, stride); + U sigma2B = WARP_SHFL_DOWN(sigma2, stride); cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); } + // threadIdx.x == 0 has correct values for each warp // inter-warp reductions if (blockDim.y > 1) { @@ -192,8 +193,8 @@ __device__ void cuWelfordMuSigma2( for (; l + 7 < n2; l += 8 * numx) { for (int k = 0; k < 8; k += 2) { float2 curr = __half22float2(*((__half2*)(lvals + l + k))); - cuWelfordOnlineSum(curr.x, mu, sigma2, count); - cuWelfordOnlineSum(curr.y, mu, sigma2, count); + cuWelfordOnlineSum(static_cast(curr.x), mu, sigma2, count); + cuWelfordOnlineSum(static_cast(curr.y), mu, sigma2, count); } } for (; l < n2; ++l) { @@ -201,13 +202,14 @@ __device__ void cuWelfordMuSigma2( cuWelfordOnlineSum(curr, mu, sigma2, count); } // intra-warp reductions - for (int l = 0; l <= 4; ++l) { - int srcLaneB = (threadIdx.x + (1 << l)) & 31; - float muB = WARP_SHFL(mu, srcLaneB); - float countB = WARP_SHFL(count, srcLaneB); - float sigma2B = WARP_SHFL(sigma2, srcLaneB); + #pragma unroll + for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) { + float muB = WARP_SHFL_DOWN(mu, stride); + float countB = WARP_SHFL_DOWN(count, stride); + float sigma2B = WARP_SHFL_DOWN(sigma2, stride); cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); } + // threadIdx.x == 0 has correct values for each warp // inter-warp reductions if (blockDim.y > 1) { @@ -310,7 +312,7 @@ __global__ void cuApplyLayerNorm( // 1) blockDim.x == GPU_WARP_SIZE // 2) Tensors are contiguous // - for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { SharedMemory shared; U* buf = shared.getPointer(); U mu, sigma2; diff --git a/onnxruntime/contrib_ops/cuda/tensor/image_scaler.h b/onnxruntime/contrib_ops/cuda/tensor/image_scaler.h index 70f6590f62..de431d45f5 100644 --- a/onnxruntime/contrib_ops/cuda/tensor/image_scaler.h +++ b/onnxruntime/contrib_ops/cuda/tensor/image_scaler.h @@ -5,7 +5,7 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" -#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/cuda_common.h" namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 16e2a504a2..40dfa89a7e 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2376,16 +2376,47 @@ Node& Graph::AddNode(const NodeProto& node_proto, } std::string Graph::GenerateNodeArgName(const std::string& base_name) { + // Check if base_name has been used in as any of node_args_' names. + bool found = node_args_.find(base_name) != node_args_.end(); + // Check if base_name has been generated by this function. + // If not, add base_name into name set and return the base_name + // as the generated name. + found |= generated_node_arg_names_.find(base_name) != generated_node_arg_names_.end(); + if (!found) { + generated_node_arg_names_.insert(base_name); + return base_name; + } + + // base_name has been used by another node. Because two node_arg's cannot have + // the sam name, we are going to generate another string. std::string new_name; do { std::ostringstream str; str << base_name << "_" << name_generator_++; new_name = str.str(); - } while (node_args_.find(new_name) != node_args_.end()); + // If node_args_ or generated_node_arg_names_ contains new_name, we go to the next iteration. + } while (node_args_.find(new_name) != node_args_.end() || + generated_node_arg_names_.find(new_name) != generated_node_arg_names_.end()); + + // Now new_name is different than any of existing node_arg names. + // Register new_name so that it won't be used again. + generated_node_arg_names_.insert(new_name); + return new_name; } std::string Graph::GenerateNodeName(const std::string& base_name) { + // Check if base_name has been used in as any of nodes_' names. + bool found = std::find_if(nodes_.cbegin(), nodes_.cend(), [&base_name](const std::unique_ptr& n) { + return (n != nullptr) && (n->Name() == base_name);}) != nodes_.end(); + // Check if base_name has been generated by this function. + found |= generated_node_names_.find(base_name) != generated_node_names_.end(); + if (!found) { + // Register base_name so that it won't be used again. + generated_node_names_.insert(base_name); + return base_name; + } + std::string new_name; bool keep_going = true; @@ -2394,11 +2425,18 @@ std::string Graph::GenerateNodeName(const std::string& base_name) { str << base_name << "_" << name_generator_++; new_name = str.str(); + // Check if new_name has been used in as any of nodes_' names. keep_going = std::find_if(nodes_.cbegin(), nodes_.cend(), [&new_name](const std::unique_ptr& n) { return (n != nullptr) && (n->Name() == new_name); }) != nodes_.end(); + // Check if new_name has been generated by this function. + keep_going |= generated_node_names_.find(new_name) != generated_node_names_.end(); } while (keep_going); + // Now new_name is different than any of existing node names. + // Register new_name so that it won't be used again. + generated_node_names_.insert(new_name); + return new_name; } diff --git a/onnxruntime/core/optimizer/expand_elimination.cc b/onnxruntime/core/optimizer/expand_elimination.cc new file mode 100644 index 0000000000..c6ae692f2d --- /dev/null +++ b/onnxruntime/core/optimizer/expand_elimination.cc @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/logging/logging.h" +#include "core/framework/tensorprotoutils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/rewrite_rule.h" +#include "core/optimizer/expand_elimination.h" +#include "core/graph/graph.h" +#include "core/graph/graph_utils.h" + +namespace onnxruntime { + +Status ExpandElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { + if (graph_utils::RemoveNode(graph, node)) { + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + } + + return Status::OK(); +} + +bool ExpandElimination::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { + if (!graph_utils::CanRemoveNode(graph, node, logger)) { + return false; + } + + // 1. Check if has input shape. + const auto* input_shape = node.InputDefs()[0]->Shape(); + if (input_shape == nullptr) { + return false; + } + + // 2. Get target shape if it's constant. + const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); + if (tensor_proto == nullptr || tensor_proto->dims_size() != 1 || tensor_proto->dims(0) <= 0) { + return false; + } + + auto initializer = onnxruntime::make_unique(*tensor_proto, graph.ModelPath()); + if (initializer->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT64) { + return false; + } + + const int64_t* target_shapes = initializer->data(); + + // Check the dimensions starting at the trailing dimension. + int i = input_shape->dim_size() - 1; + int j = static_cast(tensor_proto->dims(0) - 1); + + // The Expand produces same input tensor only when target dimension size is not greater than input's. + if (i < j) { + return false; + } + + while (i >= 0 && j >= 0) { + auto dim = input_shape->dim(i); + if (utils::HasDimValue(dim)) { + auto dim_value = dim.dim_value(); + if (dim_value != target_shapes[j] && target_shapes[j] > 1) { + return false; + } + } else if (target_shapes[j] > 1) { + return false; + } + + --i; + --j; + } + + + return true; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/expand_elimination.h b/onnxruntime/core/optimizer/expand_elimination.h new file mode 100644 index 0000000000..b31ce8fbca --- /dev/null +++ b/onnxruntime/core/optimizer/expand_elimination.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { + +/** +@Class ExpandElimination + +Rewrite rule that eliminates Expand nodes if the node generate the same tensor as the input tensor. + +It is attempted to be triggered only on nodes with op type "Expand". +*/ +class ExpandElimination : public RewriteRule { + public: + ExpandElimination() noexcept : RewriteRule("ExpandElimination") {} + + std::vector TargetOpTypes() const noexcept override { + return {"Expand"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 2ef55402f0..86ea68b67c 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -27,6 +27,7 @@ #include "core/optimizer/embed_layer_norm_fusion.h" #include "core/optimizer/reshape_fusion.h" #include "core/optimizer/attention_fusion.h" +#include "core/optimizer/expand_elimination.h" #include "core/optimizer/cast_elimination.h" #include "core/mlas/inc/mlas.h" @@ -47,6 +48,7 @@ std::vector> GenerateRewriteRules(TransformerLevel rules.push_back(onnxruntime::make_unique()); rules.push_back(onnxruntime::make_unique()); rules.push_back(onnxruntime::make_unique()); + rules.push_back(onnxruntime::make_unique()); rules.push_back(onnxruntime::make_unique()); rules.push_back(onnxruntime::make_unique()); rules.push_back(onnxruntime::make_unique()); diff --git a/onnxruntime/core/providers/cuda/math/gemm.cc b/onnxruntime/core/providers/cuda/math/gemm.cc index 87e500cc19..21b771fdbf 100644 --- a/onnxruntime/core/providers/cuda/math/gemm.cc +++ b/onnxruntime/core/providers/cuda/math/gemm.cc @@ -3,7 +3,7 @@ #include "gemm.h" #include "core/providers/cpu/math/gemm_helper.h" -#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/cuda/math/softmax_impl.cuh b/onnxruntime/core/providers/cuda/math/softmax_impl.cuh index b9c07da7f2..19508f8538 100644 --- a/onnxruntime/core/providers/cuda/math/softmax_impl.cuh +++ b/onnxruntime/core/providers/cuda/math/softmax_impl.cuh @@ -17,7 +17,6 @@ // The code below is mostly copied from Pytorch PersistentSoftmax.cuh #pragma once - #include "core/providers/cuda/cu_inc/common.cuh" namespace onnxruntime { diff --git a/onnxruntime/core/providers/cuda/nn/shrink.h b/onnxruntime/core/providers/cuda/nn/shrink.h index 68fd27d00d..850dd9781e 100644 --- a/onnxruntime/core/providers/cuda/nn/shrink.h +++ b/onnxruntime/core/providers/cuda/nn/shrink.h @@ -4,7 +4,7 @@ #pragma once #include "gsl/gsl" -#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/cuda_common.h" namespace onnxruntime { namespace cuda { diff --git a/onnxruntime/core/providers/cuda/tensor/concat.h b/onnxruntime/core/providers/cuda/tensor/concat.h index 7c542709bc..9cefcafb59 100644 --- a/onnxruntime/core/providers/cuda/tensor/concat.h +++ b/onnxruntime/core/providers/cuda/tensor/concat.h @@ -11,7 +11,7 @@ namespace cuda { class Concat final : public CudaKernel, public ConcatBase { public: - Concat(const OpKernelInfo& info) : ConcatBase(info), CudaKernel(info) {} + Concat(const OpKernelInfo& info) : CudaKernel(info), ConcatBase(info) {} Status ComputeInternal(OpKernelContext* context) const override; }; diff --git a/onnxruntime/core/providers/cuda/tensor/gather.h b/onnxruntime/core/providers/cuda/tensor/gather.h index bc7e2508f2..917bff8fc4 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather.h +++ b/onnxruntime/core/providers/cuda/tensor/gather.h @@ -11,7 +11,7 @@ namespace cuda { class Gather final : public CudaKernel, public GatherBase { public: - Gather(const OpKernelInfo& info) : GatherBase(info), CudaKernel(info) {} + Gather(const OpKernelInfo& info) : CudaKernel(info), GatherBase(info) {} Status ComputeInternal(OpKernelContext* context) const override; }; diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements_impl.cu b/onnxruntime/core/providers/cuda/tensor/scatter_elements_impl.cu index 83422c14c5..0bf51f741d 100755 --- a/onnxruntime/core/providers/cuda/tensor/scatter_elements_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements_impl.cu @@ -1,13 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/providers/cuda/atomic/common.cuh" #include "core/providers/cuda/cu_inc/common.cuh" #include "scatter_elements_impl.h" +#ifdef ENABLE_TRAINING +#include "orttraining/training_ops/cuda/tensor/gather_elements_grad_impl.h" +#endif namespace onnxruntime { namespace cuda { -template +template __global__ void _ScatterElementsKernel2D( const int max_dim, // max dim on the scattered axis const T* input_data, @@ -16,7 +20,8 @@ __global__ void _ScatterElementsKernel2D( const fast_divmod indices_stride_row, const T* updates, const int64_t output_row_size, - T* output_data) { + T* output_data, + const FuncT& func) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(indices_index, indices_size); int row, col, data_idx; @@ -29,12 +34,13 @@ __global__ void _ScatterElementsKernel2D( } else { data_idx = row * output_row_size + dim; } - output_data[data_idx] = updates[indices_index]; + + func(output_data + data_idx, updates + indices_index); } // else invalid index } -template +template __global__ void _ScatterElementsKernel( const int rank, const T* input_data, @@ -46,7 +52,8 @@ __global__ void _ScatterElementsKernel( const TArray indices_strides, const T* updates, const int axis, - T* output_data) { + T* output_data, + const FuncT& func) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(indices_index, indices_size); int dim, remain = indices_index; size_t data_idx = 0; @@ -61,7 +68,7 @@ __global__ void _ScatterElementsKernel( } data_idx += input_strides[i] * dim; } - output_data[data_idx] = updates[indices_index]; + func(output_data + data_idx, updates + indices_index); } // From the innermost axis (largest) check equality of dim value of input and indices. @@ -135,7 +142,7 @@ static int CompactInputIndicesDims( return new_axis; } -template +template Status ScatterElementsImpl2D( const T* input_data, const std::vector& input_dims, @@ -144,23 +151,70 @@ Status ScatterElementsImpl2D( const std::vector& indices_dims, const T* updates, const int axis, - T* output_data) { + T* output_data, + const FuncT& func) { int blocksPerGrid = gsl::narrow_cast(CeilDiv(indices_size, GridDim::maxThreadsPerBlock)); fast_divmod indices_stride_row(indices_dims[1]); if (axis == 0) { - _ScatterElementsKernel2D<<>>( + _ScatterElementsKernel2D<<>>( gsl::narrow_cast(input_dims[0]), input_data, indices_data, indices_size, indices_stride_row, - updates, input_dims[1], output_data); + updates, input_dims[1], output_data, func); } else { - _ScatterElementsKernel2D<<>>( + _ScatterElementsKernel2D<<>>( gsl::narrow_cast(input_dims[1]), input_data, indices_data, indices_size, indices_stride_row, - updates, input_dims[1], output_data); + updates, input_dims[1], output_data, func); } return Status::OK(); } +template +Status ScatterElementsImplInternal( + const int rank, + const T* input_data, + const int64_t input_size, + TArray& buffer_input_dims, + TArray& buffer_input_strides, + const Tin* indices_data, + const int64_t indices_size, + TArray& buffer_indices_dims, + TArray& fdm_indices_strides, + const T* updates, + const int axis, + T* output_data, + const FuncT& func) { + if (input_data != output_data) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_data, input_data, input_size * sizeof(T), cudaMemcpyDeviceToDevice, 0)); + } + + if (indices_size > 0) { + std::vector eff_input_dims; + std::vector eff_indices_dims; + int new_axis = CompactInputIndicesDims( + rank, axis, buffer_input_dims.data_, buffer_indices_dims.data_, eff_input_dims, eff_indices_dims); + if (eff_input_dims.size() == 2) { + return ScatterElementsImpl2D( + input_data, eff_input_dims, indices_data, indices_size, eff_indices_dims, updates, new_axis, output_data, + func); + } + + int blocksPerGrid = gsl::narrow_cast(CeilDiv(indices_size, GridDim::maxThreadsPerBlock)); + _ScatterElementsKernel<<>>( + rank, input_data, buffer_input_dims, buffer_input_strides, + indices_data, indices_size, buffer_indices_dims, fdm_indices_strides, + updates, axis, output_data, func); + } + return Status::OK(); +} + +template +struct Func_Assignment { + __device__ __inline__ void operator()(T* a, const T* b) const { + *a = *b; + } +}; + template Status ScatterElementsImpl( const int rank, @@ -175,60 +229,94 @@ Status ScatterElementsImpl( const T* updates, const int axis, T* output_data) { - if (input_data != output_data) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_data, input_data, input_size * sizeof(T), cudaMemcpyDeviceToDevice, 0)); - } - - if (indices_size > 0) { - std::vector eff_input_dims; - std::vector eff_indices_dims; - int new_axis = CompactInputIndicesDims( - rank, axis, buffer_input_dims.data_, buffer_indices_dims.data_, eff_input_dims, eff_indices_dims); - if (eff_input_dims.size() == 2) { - return ScatterElementsImpl2D( - input_data, eff_input_dims, indices_data, indices_size, eff_indices_dims, updates, new_axis, output_data); - } - - int blocksPerGrid = gsl::narrow_cast(CeilDiv(indices_size, GridDim::maxThreadsPerBlock)); - _ScatterElementsKernel<<>>( - rank, input_data, buffer_input_dims, buffer_input_strides, - indices_data, indices_size, buffer_indices_dims, fdm_indices_strides, - updates, axis, output_data); - } - return Status::OK(); + return ScatterElementsImplInternal(rank, input_data, input_size, buffer_input_dims, + buffer_input_strides, indices_data, indices_size, buffer_indices_dims, fdm_indices_strides, + updates, axis, output_data, Func_Assignment()); } -#define SPECIALIZED_TINDEX_IMPL(T, TIndex) \ - template Status ScatterElementsImpl( \ - const int rank, \ - const T* input_data, \ - const int64_t input_size, \ - TArray& buffer_input_dims, \ - TArray& buffer_input_strides, \ - const TIndex* indices_data, \ - const int64_t indices_size, \ - TArray& buffer_indices_dims, \ - TArray& indices_strides, \ - const T* updates, \ - const int axis, \ +#define SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, TIndex) \ + template Status ScatterElementsImpl( \ + const int rank, \ + const T* input_data, \ + const int64_t input_size, \ + TArray& buffer_input_dims, \ + TArray& buffer_input_strides, \ + const TIndex* indices_data, \ + const int64_t indices_size, \ + TArray& buffer_indices_dims, \ + TArray& indices_strides, \ + const T* updates, \ + const int axis, \ T* output_data) -#define SPECIALIZED_IMPL(T) \ - SPECIALIZED_TINDEX_IMPL(T, int32_t); \ - SPECIALIZED_TINDEX_IMPL(T, int64_t); +#define SCATTER_ELEMENTS_SPECIALIZED_IMPL(T) \ + SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, int32_t); \ + SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, int64_t); -SPECIALIZED_IMPL(int8_t) -SPECIALIZED_IMPL(int16_t) -SPECIALIZED_IMPL(int32_t) -SPECIALIZED_IMPL(int64_t) -SPECIALIZED_IMPL(uint8_t) -SPECIALIZED_IMPL(uint16_t) -SPECIALIZED_IMPL(uint32_t) -SPECIALIZED_IMPL(uint64_t) -SPECIALIZED_IMPL(half) -SPECIALIZED_IMPL(float) -SPECIALIZED_IMPL(double) -SPECIALIZED_IMPL(bool) +SCATTER_ELEMENTS_SPECIALIZED_IMPL(int8_t) +SCATTER_ELEMENTS_SPECIALIZED_IMPL(int16_t) +SCATTER_ELEMENTS_SPECIALIZED_IMPL(int32_t) +SCATTER_ELEMENTS_SPECIALIZED_IMPL(int64_t) +SCATTER_ELEMENTS_SPECIALIZED_IMPL(uint8_t) +SCATTER_ELEMENTS_SPECIALIZED_IMPL(uint16_t) +SCATTER_ELEMENTS_SPECIALIZED_IMPL(uint32_t) +SCATTER_ELEMENTS_SPECIALIZED_IMPL(uint64_t) +SCATTER_ELEMENTS_SPECIALIZED_IMPL(half) +SCATTER_ELEMENTS_SPECIALIZED_IMPL(float) +SCATTER_ELEMENTS_SPECIALIZED_IMPL(double) +SCATTER_ELEMENTS_SPECIALIZED_IMPL(bool) + +#ifdef ENABLE_TRAINING + +template +struct Func_AtomicAdd { + __device__ __inline__ void operator()(T* a, const T* b) const { + atomic_add(a, *b); + } +}; + +template +Status GatherElementsGradImpl( + const int rank, + TArray& buffer_input_dims, + TArray& buffer_input_strides, + const Tin* indices_data, + const int64_t indices_size, + TArray& buffer_indices_dims, + TArray& fdm_indices_strides, + const T* updates, + const int axis, + T* output_data) { + // Give output_data as the input_data parameter by intention, + // to skip input_data copy, which is not applicable for GatherElementsGrad. + return ScatterElementsImplInternal(rank, output_data, 0, + buffer_input_dims, buffer_input_strides, indices_data, + indices_size, buffer_indices_dims, fdm_indices_strides, + updates, axis, output_data, Func_AtomicAdd()); +} + +#define GATHER_ELEMENTS_GRAD_SPECIALIZED_TINDEX_IMPL(T, TIndex) \ + template Status GatherElementsGradImpl( \ + const int rank, \ + TArray& buffer_input_dims, \ + TArray& buffer_input_strides, \ + const TIndex* indices_data, \ + const int64_t indices_size, \ + TArray& buffer_indices_dims, \ + TArray& indices_strides, \ + const T* updates, \ + const int axis, \ + T* output_data) + +#define GATHER_ELEMENTS_GRAD_SPECIALIZED_SCATTER_ADD_IMPL(T) \ + GATHER_ELEMENTS_GRAD_SPECIALIZED_TINDEX_IMPL(T, int32_t); \ + GATHER_ELEMENTS_GRAD_SPECIALIZED_TINDEX_IMPL(T, int64_t); + +GATHER_ELEMENTS_GRAD_SPECIALIZED_SCATTER_ADD_IMPL(half) +GATHER_ELEMENTS_GRAD_SPECIALIZED_SCATTER_ADD_IMPL(float) +GATHER_ELEMENTS_GRAD_SPECIALIZED_SCATTER_ADD_IMPL(double) + +#endif } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements_impl.h b/onnxruntime/core/providers/cuda/tensor/scatter_elements_impl.h index 2f08d542e0..5eea6ab808 100755 --- a/onnxruntime/core/providers/cuda/tensor/scatter_elements_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements_impl.h @@ -26,4 +26,3 @@ Status ScatterElementsImpl( } // namespace cuda } // namespace onnxruntime - diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index bc7926644c..27200ec41a 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -40,6 +40,7 @@ #include "core/optimizer/reshape_fusion.h" #include "core/optimizer/attention_fusion.h" #include "core/optimizer/fast_gelu_fusion.h" +#include "core/optimizer/expand_elimination.h" #include "core/optimizer/cast_elimination.h" #include "core/optimizer/utils.h" #include "core/platform/env.h" @@ -1115,6 +1116,24 @@ TEST_F(GraphTransformationTests, ReshapeFusionInternalReuseTest) { } } +TEST_F(GraphTransformationTests, ExpandElimination) { + auto model_uri = MODEL_FOLDER "expand_elimination.onnx"; + std::shared_ptr model; + ASSERT_TRUE(Model::Load(model_uri, model, nullptr, *logger_).IsOK()); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Expand"] == 6); + + auto rule_transformer_L1 = onnxruntime::make_unique("RuleTransformer1"); + rule_transformer_L1->Register(onnxruntime::make_unique()); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_).IsOK()); + + op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Expand"] == 3); +} + TEST_F(GraphTransformationTests, CastElimination) { auto model_uri = MODEL_FOLDER "cast_elimination.onnx"; std::shared_ptr model; diff --git a/onnxruntime/test/testdata/test_training_model_0.onnx b/onnxruntime/test/testdata/test_training_model_0.onnx new file mode 100644 index 0000000000..2991cb758c Binary files /dev/null and b/onnxruntime/test/testdata/test_training_model_0.onnx differ diff --git a/onnxruntime/test/testdata/test_training_model_1.onnx b/onnxruntime/test/testdata/test_training_model_1.onnx new file mode 100644 index 0000000000..d7041e03bd Binary files /dev/null and b/onnxruntime/test/testdata/test_training_model_1.onnx differ diff --git a/onnxruntime/test/testdata/test_training_model_2.onnx b/onnxruntime/test/testdata/test_training_model_2.onnx new file mode 100644 index 0000000000..d0958df31a Binary files /dev/null and b/onnxruntime/test/testdata/test_training_model_2.onnx differ diff --git a/onnxruntime/test/testdata/transform/expand_elimination.onnx b/onnxruntime/test/testdata/transform/expand_elimination.onnx new file mode 100644 index 0000000000..7a42e5297f Binary files /dev/null and b/onnxruntime/test/testdata/transform/expand_elimination.onnx differ diff --git a/onnxruntime/test/testdata/transform/expand_elimination.py b/onnxruntime/test/testdata/transform/expand_elimination.py new file mode 100644 index 0000000000..debe0a9908 --- /dev/null +++ b/onnxruntime/test/testdata/transform/expand_elimination.py @@ -0,0 +1,56 @@ +import onnx +from onnx import helper +from onnx import TensorProto, GraphProto, OperatorSetIdProto +from onnx import numpy_helper +import numpy as np + +X1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [2, 1]) +X2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, ['dynamic', 4]) +Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2, 4]) + +shape_constant1 = numpy_helper.from_array(np.array([1, 4], dtype=np.int64), name='shape_constant1') +shape_constant2 = numpy_helper.from_array(np.array([1, 1], dtype=np.int64), name='shape_constant2') +shape_constant3 = numpy_helper.from_array(np.array([2, 1], dtype=np.int64), name='shape_constant3') +shape_constant4 = numpy_helper.from_array(np.array([1, 1, 1], dtype=np.int64), name='shape_constant4') +shape_constant5 = numpy_helper.from_array(np.array([1, 4], dtype=np.int64), name='shape_constant5') +shape_constant6 = numpy_helper.from_array(np.array([2, 1], dtype=np.int64), name='shape_constant6') + +identity1 = helper.make_node('Identity', ['input1'], ['identity1'], name='identity1') +expand1 = helper.make_node('Expand', ['identity1', shape_constant1.name], ['expand1'], name='expand1') +expand2 = helper.make_node('Expand', ['identity1', shape_constant2.name], ['expand2'], name='expand2') +mul1 = helper.make_node('Mul', ['expand1', 'expand2'], ['mul1'], name='mul1') # (2, 4) +expand3 = helper.make_node('Expand', ['mul1', shape_constant3.name], ['expand3'], name='expand3') +expand4 = helper.make_node('Expand', ['identity1', shape_constant4.name], ['expand4'], name='expand4') +mul2 = helper.make_node('Mul', ['expand3', 'expand4'], ['mul2'], name='mul2') # (1, 2, 4) +identity2 = helper.make_node('Identity', ['input2'], ['identity2'], name='identity2') +expand5 = helper.make_node('Expand', ['identity2', shape_constant5.name], ['expand5'], name='expand5') +expand6 = helper.make_node('Expand', ['identity2', shape_constant6.name], ['expand6'], name='expand6') +mul3 = helper.make_node('Mul', ['expand5', 'expand6'], ['mul3'], name='mul3') # (dynamic=2, 4) +mul4 = helper.make_node('Mul', ['mul2', 'mul3'], ['output'], name='mul4') + +# Create the graph (GraphProto) +graph_def = helper.make_graph( + [identity1, expand1, expand2, mul1, expand3, expand4, mul2, identity2, expand5, expand6, mul3, mul4], + 'expand_elimination_model', + [X1, X2], + [Y], + [shape_constant1, shape_constant2, shape_constant3, shape_constant4, shape_constant5, shape_constant6] +) + +opsets = [] +onnxdomain = OperatorSetIdProto() +onnxdomain.version = 12 +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +opsets.append(onnxdomain) + +msdomain = OperatorSetIdProto() +msdomain.version = 1 +msdomain.domain = 'com.microsoft' + +opsets.append(msdomain) +kwargs={} +kwargs['opset_imports'] = opsets + +# Create the model (ModelProto) +model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) +onnx.save(model_def, 'expand_elimination.onnx') diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 76b0e1aa0c..609ff30304 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -573,6 +573,17 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherGradient) { SrcNodeAttributes())}; } +IMPLEMENT_GRADIENT_BUILDER(GetGatherElementsGradient) { + return std::vector{ + NodeDef("Shape", + {I(0)}, + {IA("x_shape")}), + NodeDef(OpDef{"GatherElementsGrad", kMSDomain, 1}, + {GO(0), IA("x_shape"), I(1)}, + {GI(0)}, + SrcNodeAttributes())}; +}; + IMPLEMENT_GRADIENT_BUILDER(GetReluGradient) { return std::vector{ NodeDef("ReluGrad", @@ -1034,7 +1045,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetSendGradient) { } return std::vector{ - NodeDef("Recv", + NodeDef(OpDef{"Recv", kMSDomain, 1}, {O(0), I(1)}, // {Signal, Remote} out_args, SrcNodeAttributes())}; @@ -1046,18 +1057,38 @@ IMPLEMENT_GRADIENT_BUILDER(GetRecvGradient) { std::vector in_args; in_args.push_back(O(0)); // Signal - in_args.push_back(I(0)); // Remote + in_args.push_back(I(1)); // Remote for (int i = 1; i < GetSrcNodeOutputSize(); ++i) { in_args.push_back(GO(i)); // Data } return std::vector{ - NodeDef("Send", + NodeDef(OpDef{"Send", kMSDomain, 1}, in_args, {GI(0)}, // Signal SrcNodeAttributes())}; } +IMPLEMENT_GRADIENT_BUILDER(GetExpandGradient) { + ArgDef a = I(0), y = O(0); + std::vector a_shape = GetShape(a); + std::vector y_shape = GetShape(y); + std::vector a_axes; + ComputeBroadcastBackwardAxes(a_shape, y_shape, &a_axes, nullptr); + + std::vector output; + if (a_axes.size() > 0) { + HandleBroadcasting(GO(0), a, GI(0), a_axes, output); + } else { + output.push_back( + NodeDef("Identity", + {GO(0)}, + {GI(0)})); + } + + return output; +} + } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index a2f67c76e8..465209da84 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -45,6 +45,7 @@ DECLARE_GRADIENT_BUILDER(GetGemmGradient) DECLARE_GRADIENT_BUILDER(GetDropoutGradient) DECLARE_GRADIENT_BUILDER(GetTrainableDropoutGradient) DECLARE_GRADIENT_BUILDER(GetGatherNDGradient) +DECLARE_GRADIENT_BUILDER(GetGatherElementsGradient) DECLARE_GRADIENT_BUILDER(GetGeluGradient) DECLARE_GRADIENT_BUILDER(GetLayerNormalizationGradient) DECLARE_GRADIENT_BUILDER(GetBatchNormalizationGradient) @@ -55,6 +56,7 @@ DECLARE_GRADIENT_BUILDER(GetFastGeluGradient) DECLARE_GRADIENT_BUILDER(GetWhereGradient) DECLARE_GRADIENT_BUILDER(GetSendGradient) DECLARE_GRADIENT_BUILDER(GetRecvGradient) +DECLARE_GRADIENT_BUILDER(GetExpandGradient) } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder_base.cc b/orttraining/orttraining/core/graph/gradient_builder_base.cc index ca0aa9bb54..3a63f2f8fa 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_base.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_base.cc @@ -61,6 +61,7 @@ void ComputeBroadcastBackwardAxes( } std::vector GetShape(const ArgDef& arg_def) { + ORT_ENFORCE(arg_def.type_proto, "During GetShape, ", arg_def.name, "'s type_proto is null."); std::vector shape; const auto& dims = arg_def.type_proto->tensor_type().shape().dim(); for (auto dim = dims.begin(); dim < dims.end(); dim++) { diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 90acb6dff3..e867a00c40 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -83,6 +83,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("Dropout", GetDropoutGradient) REGISTER_GRADIENT_BUILDER("TrainableDropout", GetTrainableDropoutGradient) REGISTER_GRADIENT_BUILDER("GatherND", GetGatherNDGradient) + REGISTER_GRADIENT_BUILDER("GatherElements", GetGatherElementsGradient) REGISTER_GRADIENT_BUILDER("Gelu", GetGeluGradient) REGISTER_GRADIENT_BUILDER("LayerNormalization", GetLayerNormalizationGradient); REGISTER_GRADIENT_BUILDER("BatchNormalization", GetBatchNormalizationGradient); @@ -93,6 +94,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("Where", GetWhereGradient); REGISTER_GRADIENT_BUILDER("Send", GetSendGradient); REGISTER_GRADIENT_BUILDER("Recv", GetRecvGradient); + REGISTER_GRADIENT_BUILDER("Expand", GetExpandGradient); }; } // namespace training diff --git a/orttraining/orttraining/core/graph/gradient_schema_defs.cc b/orttraining/orttraining/core/graph/gradient_schema_defs.cc index 1a747fe332..9249dd8410 100644 --- a/orttraining/orttraining/core/graph/gradient_schema_defs.cc +++ b/orttraining/orttraining/core/graph/gradient_schema_defs.cc @@ -432,6 +432,43 @@ void RegisterGradientSchemas() { {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types"); + ONNX_CONTRIB_OPERATOR_SCHEMA(GatherElementsGrad) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL) + .SetDoc("GatherElementsGrad") + .Attr( + "axis", + "Which axis to scatter on. Negative value means " + "counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(data).", + AttributeProto::INT, + static_cast(0)) + .Input( + 0, + "dY", + "Tensor of rank r >=1 (same rank and shape as indices)", + "T") + .Input(1, "shape", "Shape of the GatherElements input data.", "I") + .Input( + 2, + "indices", + "Tensor of int32/int64 indices, of r >= 1 (same rank as input). All index values are expected to be " + "within bounds [-s, s-1] along axis of size s. It is an error if any of the index values are out of bounds.", + "Tind") + .Output(0, "dX", "Tensor of rank r >= 1 (same rank as input).", "T") + .TypeConstraint( + "I", + {"tensor(int64)"}, + "Constrain input shape to integer tensors.") + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Input and output types can be of any tensor type.") + .TypeConstraint( + "Tind", + {"tensor(int32)", "tensor(int64)"}, + "Constrain indices to integer types"); + ONNX_CONTRIB_OPERATOR_SCHEMA(DivGrad) .SetDomain(kMSDomain) .SinceVersion(1) @@ -1633,7 +1670,7 @@ Return true if all elements are true and false otherwise. "Allow inputs and outputs to be any kind of tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { if (ctx.getNumInputs() < ctx.getNumOutputs() + 1) - fail_shape_inference("WaitEvent must have at least (num_outputs + 1) inputs."); + fail_shape_inference("RecordEvent must have at least (num_outputs + 1) inputs."); // note: if num_input > num_output + 1, // the additional inputs (idx >= num_ouput + 1) are regarded as dependencies diff --git a/orttraining/orttraining/core/graph/pipeline_transformer.cc b/orttraining/orttraining/core/graph/pipeline_transformer.cc new file mode 100644 index 0000000000..39700c6e9b --- /dev/null +++ b/orttraining/orttraining/core/graph/pipeline_transformer.cc @@ -0,0 +1,279 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/core/graph/pipeline_transformer.h" + +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace training { + +void GetPipelineSendOutput(const Graph& graph, std::string& loss_name) { + for (auto& node : graph.Nodes()) { + if (!node.OpType().compare("Send")) { + // send op should always have an output, which is the OutputSignal. + loss_name = node.OutputDefs()[0]->Name(); + return; + } + } +} + +bool IsBackward(Node& node) { + return (node.Description() == "Backward pass"); +} + +void AddInputEvent(Graph& graph, const std::string& op_name, + bool is_forward, + std::vector& input_args, + std::vector& new_input_names) { + ONNX_NAMESPACE::TypeProto event_type_proto; + event_type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + + auto event_id_name = graph.GenerateNodeArgName(op_name + (is_forward ? "_fw" : "_bw") + "_event_id"); + auto& event_id = graph.GetOrCreateNodeArg(event_id_name, &event_type_proto); + new_input_names.push_back(event_id_name); + + input_args.push_back(&event_id); +} + +// gradient graph can contain some dangling leaf nodes. Add them all to WaitEvent +// backward node's input. +void FindLeafNodes(Graph& graph, std::vector& input_args) { + for (auto& node : graph.Nodes()) { + if (!IsBackward(node)) { + // only check backward node + continue; + } + bool find_consumer_nodes = false; + std::vector& outputs = node.MutableOutputDefs(); + for (auto& output : outputs) { + std::vector consumer_nodes = graph.GetConsumerNodes(output->Name()); + if (consumer_nodes.size() > 0) { + find_consumer_nodes = true; + break; + } + } + if (!find_consumer_nodes && outputs.size() > 0) { + input_args.push_back(outputs[0]); + } + } +}; + +NodeArg& CreateNodeArg(Graph& graph, const NodeArg* base_arg) { + const auto& new_name = graph.GenerateNodeArgName(base_arg->Name()); + ONNX_NAMESPACE::TypeProto type_proto(*(base_arg->TypeAsProto())); + if (graph.GetNodeArg(new_name) != nullptr) { + ORT_THROW("Node with name ", new_name, " already exists."); + } + return graph.GetOrCreateNodeArg(new_name, &type_proto); +} + +Status AddRecordBackward(Graph& graph, + Node* send_bw, + std::vector& new_input_names, + std::vector& new_output_names) { + std::vector input_args; + AddInputEvent(graph, "RecordEvent", false /* is_forward */, input_args, new_input_names); + std::vector output_args{}; + + if (send_bw) { + // if we have send op in backward pass (at the end of the graph), we make sure the RecordEvent happens + // after that send by adding Send's outputs to RecordEvent's input list. + input_args.insert(std::end(input_args), + std::begin(send_bw->MutableOutputDefs()), + std::end(send_bw->MutableOutputDefs())); + } + FindLeafNodes(graph, input_args); + + // Optimizer will be added after applying pipeline transformer. To support partial graph evaluation, + // the added Record backward op will have its first passthrough input as output. + ORT_RETURN_IF_NOT(input_args.size() >= 2, "RecordEvent backward op at least have two inputs.") + auto& new_output = CreateNodeArg(graph, input_args[1]); // the first input is signal, not passing through + output_args.push_back(&new_output); + new_output_names.push_back(new_output.Name()); + + graph.AddNode(graph.GenerateNodeName("RecordEvent"), + "RecordEvent", + "Backward pass", + input_args, + output_args, + nullptr, + kMSDomain); + return Status::OK(); +} + +Status AddWaitForward(Graph& graph, Node* /* recv_fw */, std::vector& new_input_names) { + // Append old_input to input_args and return its pass-through value. Note that + // input_args and output_args are Wait's inputs and outputs, respectively. + auto update_wait_input_output = [&](NodeArg* old_input, + std::vector& input_args, + std::vector& output_args) -> NodeArg& { + input_args.push_back(old_input); + + const auto& new_name = graph.GenerateNodeArgName(old_input->Name()); + ONNX_NAMESPACE::TypeProto type_proto(*(old_input->TypeAsProto())); + + auto& wait_output = graph.GetOrCreateNodeArg(new_name, &type_proto); + output_args.push_back(&wait_output); + + return wait_output; + }; + + std::vector input_args; + std::vector output_args; + AddInputEvent(graph, "WaitEvent", true /* is_forward */, input_args, new_input_names); + const std::vector& graph_inputs = graph.GetInputsIncludingInitializers(); + + if (graph_inputs.size() == 0){ + ORT_THROW("Graph ", graph.Name(), " doesn't have any inputs."); + } + + for (auto& input_arg : graph_inputs) { + NodeArg* mutable_input = graph.GetNodeArg(input_arg->Name()); + auto& wait_output = update_wait_input_output(mutable_input, input_args, output_args); + std::vector nodes = graph.GetMutableConsumerNodes(input_arg->Name()); + for (auto& consumer_node : nodes) { + for (auto& i : consumer_node->MutableInputDefs()) { + if (i->Name() == input_arg->Name()) { + // if the node is fed by input, re-direct it to be fed by WaitEvent's output. + i = &wait_output; + } + } + } + } + graph.AddNode(graph.GenerateNodeName("WaitEvent"), + "WaitEvent", + "", + input_args, + output_args, + nullptr, + kMSDomain); + + return Status::OK(); +} + +Status AddOrSkipRecordForwardWaitBackward(Graph& graph, Node* send_fw, Node* recv_bw, std::vector& new_input_names) { + if (!send_fw != !recv_bw){ + ORT_THROW("Graph requires either having both send forward node " + "and recv backword node, or none of them. Currently the graph " + "has send forward: ", send_fw, " and recv backward: ", recv_bw); + } + + if (!send_fw && !recv_bw){ + // Last partition doesn't have send forwrad and recv backward. No insert needed. + return Status::OK(); + } + + // if we have a send forward op followed by a recv backward op, insert WaitEvent and RecordEvent in between. + Node* record_node = nullptr; + Node* wait_node = nullptr; + + // Insert RecordEvent + { + std::vector input_args; + std::vector output_args; + AddInputEvent(graph, "RecordEvent", true /* is_forward */, input_args, new_input_names); + + // Add send forward op's output as record op's input and output + for (auto& output : send_fw->MutableOutputDefs()) { + auto& new_output = CreateNodeArg(graph, output); + output_args.push_back(&new_output); + input_args.push_back(output); + } + + auto& new_node = graph.AddNode(graph.GenerateNodeName("RecordEvent"), + "RecordEvent", + "", + input_args, + output_args, /* output */ + {}, /* attribute */ + kMSDomain); + record_node = &new_node; + } + // Insert WaitEvent + { + std::vector input_args; + std::vector output_args; + AddInputEvent(graph, "WaitEvent", false /* is_forward */, input_args, new_input_names); + + input_args.insert(std::end(input_args), + std::begin(record_node->MutableOutputDefs()), + std::end(record_node->MutableOutputDefs())); + + auto& input = recv_bw->MutableInputDefs()[0]; + auto& new_output = CreateNodeArg(graph, input); + output_args.push_back(&new_output); + input = &new_output; + + auto& new_node = graph.AddNode(graph.GenerateNodeName("WaitEvent"), + "WaitEvent", + "Backward pass", + input_args, + output_args, /* output */ + {}, /* attribute */ + kMSDomain); + wait_node = &new_node; + ORT_UNUSED_PARAMETER(wait_node); + } + + return Status::OK(); +} + +Status TransformGraphForPipeline(Graph& graph) { + // insert WaitEvent and RecordEvent to the partition + Node* send_fw{nullptr}; + Node* send_bw{nullptr}; + Node* recv_fw{nullptr}; + Node* recv_bw{nullptr}; + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Send") { + if (IsBackward(node)) { + send_bw = &node; + } else { + send_fw = &node; + } + } else if (node.OpType() == "Recv") { + if (IsBackward(node)) { + recv_bw = &node; + } else { + recv_fw = &node; + } + } + } + + std::vector new_input_names; + std::vector new_output_names; + + ORT_RETURN_IF_ERROR(AddRecordBackward(graph, send_bw, new_input_names, new_output_names)); + ORT_RETURN_IF_ERROR(AddWaitForward(graph, recv_fw, new_input_names)); + ORT_RETURN_IF_ERROR(AddOrSkipRecordForwardWaitBackward(graph, send_fw, recv_bw, new_input_names)); + + auto fill_node_args = [&](const Graph& graph, + const std::vector& existed_node_args, + std::vector& new_node_arg_names, + std::vector& merged_node_args) { + merged_node_args.insert(merged_node_args.end(), existed_node_args.begin(), existed_node_args.end()); + for (auto& name : new_node_arg_names) { + merged_node_args.push_back(graph.GetNodeArg(name)); + } + }; + + const std::vector& graph_inputs = graph.GetInputsIncludingInitializers(); + std::vector inputs_args_sets; + inputs_args_sets.reserve(graph_inputs.size() + new_input_names.size()); + fill_node_args(graph, graph_inputs, new_input_names, inputs_args_sets); + + const std::vector& graph_outputs = graph.GetOutputs(); + std::vector outputs_args_sets; + outputs_args_sets.reserve(graph_outputs.size() + new_output_names.size()); + fill_node_args(graph, graph_outputs, new_output_names, outputs_args_sets); + + graph.SetInputs(inputs_args_sets); + graph.SetOutputs(outputs_args_sets); + graph.SetGraphResolveNeeded(); + graph.SetGraphProtoSyncNeeded(); + return graph.Resolve(); +} + +} // namespace training +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/pipeline_transformer.h b/orttraining/orttraining/core/graph/pipeline_transformer.h new file mode 100644 index 0000000000..11519024e1 --- /dev/null +++ b/orttraining/orttraining/core/graph/pipeline_transformer.h @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/graph/graph.h" + +namespace onnxruntime { +namespace training { + +void GetPipelineSendOutput(const Graph& graph, std::string& loss_name); +common::Status TransformGraphForPipeline(Graph& graph); + +} // namespace training +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 27bfdb29cf..d7b76764aa 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -12,6 +12,7 @@ #include "core/optimizer/conv_add_fusion.h" #include "core/optimizer/constant_folding.h" #include "core/optimizer/unsqueeze_elimination.h" +#include "core/optimizer/expand_elimination.h" #include "core/optimizer/cast_elimination.h" #include "core/optimizer/rule_based_graph_transformer.h" #include "core/optimizer/conv_activation_fusion.h" @@ -54,6 +55,7 @@ std::vector> GeneratePreTrainingTransformers(T rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); + rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index 57e00e5780..117ffcf262 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -15,6 +15,7 @@ #include "core/optimizer/rule_based_graph_transformer.h" #include "orttraining/core/graph/mixed_precision_transformer.h" #include "orttraining/core/graph/tensorboard_transformer.h" +#include "orttraining/core/graph/pipeline_transformer.h" #include "orttraining/core/graph/gradient_builder_base.h" //Gist Encoding @@ -128,15 +129,23 @@ Status TrainingSession::ConfigureForTraining( is_mixed_precision_enabled_ = config.mixed_precision_config.has_value(); std::string loss_name{}; - const optional loss_function_info = - config.loss_function_config.has_value() - ? config.loss_function_config.value().loss_function_info - : optional{}; optional loss_scale_input_name = - is_mixed_precision_enabled_ ? optional{""} : optional{}; - ORT_RETURN_IF_ERROR(ConfigureLossFunction( - config.loss_name, loss_function_info, - loss_scale_input_name.has_value() ? &loss_scale_input_name.value() : nullptr, loss_name)); + is_mixed_precision_enabled_ ? optional{""} : optional{}; + if (config.use_pipeline) { + // if use pipeline, first check if model contains send op. If it does, set the + // send node's output as the start tensor to build gradient graph + GetPipelineSendOutput(model_->MainGraph(), loss_name); + } + if (loss_name.empty()) { + const optional loss_function_info = + config.loss_function_config.has_value() + ? config.loss_function_config.value().loss_function_info + : optional{}; + ORT_RETURN_IF_ERROR(ConfigureLossFunction( + config.loss_name, loss_function_info, + loss_scale_input_name.has_value() ? &loss_scale_input_name.value() : nullptr, loss_name)); + } + ORT_ENFORCE( !loss_scale_input_name.has_value() || !loss_scale_input_name.value().empty(), "loss_scale_input_name should not be set to an empty string."); @@ -170,7 +179,6 @@ Status TrainingSession::ConfigureForTraining( << weight_names_stream.str(); } - // add gradient graph ORT_RETURN_IF_ERROR(BuildGradientGraph( weight_names_to_train, loss_name, config.set_gradients_as_graph_outputs)); @@ -182,12 +190,30 @@ Status TrainingSession::ConfigureForTraining( weight_names_to_train, mixed_precision_config.use_fp16_initializers, fp32_weight_name_to_fp16_node_arg)); } + if (config.use_pipeline) { + ORT_RETURN_IF_ERROR(InsertPipelineOps()); + } + + // All non-float tensors are not trainable. Remove those weights. + // TODO: this is a temp workaround for removing rank tensor before adding optimizer. + // Re-visit after we port logic for model splitting and hence know the rank tensor name. + for (auto it = weights_to_train_.begin(); it != weights_to_train_.end();) { + const auto* node_arg = model_->MainGraph().GetNodeArg(*it); + ORT_RETURN_IF_NOT(node_arg, "Failed to get NodeArg with name ", *it); + if (node_arg->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + it = weights_to_train_.erase(it); + } + else{ + ++it; + } + } + // add optimizer or gradient accumulation if (config.optimizer_config.has_value()) { OptimizerGraphConfig opt_graph_config{}; std::unordered_map opt_node_configs{}; ORT_RETURN_IF_ERROR(SetupOptimizerParams( - weight_names_to_train, fp32_weight_name_to_fp16_node_arg, + weights_to_train_, fp32_weight_name_to_fp16_node_arg, loss_scale_input_name, config, opt_graph_config, opt_node_configs)); TrainingConfigurationResult::OptimizerConfigurationResult optimizer_config_result{}; @@ -198,7 +224,7 @@ Status TrainingSession::ConfigureForTraining( config_result.opt_config_result = optimizer_config_result; } else { if (config.gradient_accumulation_steps > 1) { - ORT_RETURN_IF_ERROR(BuildAccumulationNode(weight_names_to_train)); + ORT_RETURN_IF_ERROR(BuildAccumulationNode(weights_to_train_)); } } @@ -439,6 +465,11 @@ Status TrainingSession::AddTensorboard(const std::string& summary_name, return DoPostLoadProcessing(*model_); } +Status TrainingSession::InsertPipelineOps() { + ORT_RETURN_IF_ERROR(TransformGraphForPipeline(model_->MainGraph())); + return DoPostLoadProcessing(*model_); +} + Status TrainingSession::ConfigureLossFunction( const optional& external_loss_name, const optional& loss_function_info, diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h index bf06eb43d4..91725cadc1 100644 --- a/orttraining/orttraining/core/session/training_session.h +++ b/orttraining/orttraining/core/session/training_session.h @@ -132,6 +132,9 @@ class TrainingSession : public InferenceSession { // The optimizer configuration. // If not provided, no optimizer is added. optional optimizer_config{}; + + // Whether to use pipeline in training. + bool use_pipeline{false}; }; /** @@ -262,6 +265,7 @@ class TrainingSession : public InferenceSession { const std::vector& norm_nodes, const bool dump_convergence_metrics); + common::Status InsertPipelineOps(); common::Status ApplyTransformationsToMainGraph(); /** configure initial transformers for training */ diff --git a/orttraining/orttraining/models/runner/training_runner.h b/orttraining/orttraining/models/runner/training_runner.h index 23f7672c2e..94d972edca 100644 --- a/orttraining/orttraining/models/runner/training_runner.h +++ b/orttraining/orttraining/models/runner/training_runner.h @@ -169,6 +169,7 @@ class TrainingRunner { common::Status ResetLossScaler(); size_t GetRound() const { return round_; } + TrainingSession& GetSession() { return session_; } private: Status TrainingLoop(IDataLoader& training_data_loader, IDataLoader* test_data_loader); diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py index d9f58e0704..c38b92df83 100644 --- a/orttraining/orttraining/python/ort_trainer.py +++ b/orttraining/orttraining/python/ort_trainer.py @@ -510,6 +510,7 @@ def create_ort_training_session_bind_parameters(model, device, world_rank=-1, wo dtype_torch_to_numpy(torch_params[param].dtype), list(torch_tensor.size()), torch_tensor.data_ptr()) + device_index = get_device_index(device) create_and_bind_grad_or_grad_accumulate_buffer(train_io_binding, torch_tensor, param, enable_grad_accumulation, device, device_index) return session, train_io_binding, eval_io_binding, output_name, torch_params, output_types diff --git a/orttraining/orttraining/test/gradient/gradient_checker.cc b/orttraining/orttraining/test/gradient/gradient_checker.cc index 4f9599e81a..c12fae5cd1 100644 --- a/orttraining/orttraining/test/gradient/gradient_checker.cc +++ b/orttraining/orttraining/test/gradient/gradient_checker.cc @@ -128,7 +128,7 @@ inline Status GradientChecker::ComputeTheoreticalJacobianTransp const size_t dy_size = y_infos[y_idx].shape.Size(); // Compute the theoretical Jacobians one row at a time by back propagating - // '1.0'for each element of 'dy', while holding all other elements of 'dy' at zero. + // '1.0' for each element of 'dy', while holding all other elements of 'dy' at zero. for (int c = 0; c < dy_size; ++c) { // for each value in the dy input vector // clear OpTester input/output/initializer op_session.ClearData(); diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 22c3a0fd57..8dec1901b6 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -1599,6 +1599,57 @@ TEST(GradientCheckerTest, GatherNDGrad_int32_indice_unique_float_data_axis_2) { EXPECT_IS_TINY(max_error); } +TEST(GradientCheckerTest, GatherElementsGradWithDuplicateUpdate) { + float max_error; + GradientChecker gradient_checker; + OpDef op_def{"GatherElements", kOnnxDomain, 11}; + + TensorInfo data_info({3, 3}, true); + TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 2, 0, 2, 0, 0}}; + + TensorInfo y_info({2, 3}, true); + int64_t axis = 0; + + gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas, + {MakeAttribute("axis", axis)}); + EXPECT_IS_TINY(max_error); +} + +TEST(GradientCheckerTest, GatherElementsGradWithoutDuplicateUpdate) { + float max_error; + GradientChecker gradient_checker; + OpDef op_def{"GatherElements", kOnnxDomain, 11}; + + TensorInfo data_info({3, 3}, true); + TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 1, 1, 2, 2, 2}}; + + TensorInfo y_info({2, 3}, true); + int64_t axis = 0; + + gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas, + {MakeAttribute("axis", axis)}); + EXPECT_IS_TINY(max_error); +} + +TEST(GradientCheckerTest, GatherElementsGradAxisWithDuplicateUpdate) { + float max_error; + GradientChecker gradient_checker; + OpDef op_def{"GatherElements", kOnnxDomain, 11}; + + TensorInfo data_info({3, 3}, true); + TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 1, 1, 1, 1, 1}}; + + TensorInfo y_info({2, 3}, true); + int64_t axis = 1; + + gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas, + {MakeAttribute("axis", axis)}); + EXPECT_IS_TINY(max_error); +} + TEST(GradientCheckerTest, LayerNormGrad) { GradientChecker gradient_checker; { @@ -1800,6 +1851,84 @@ TEST(Synchronization, WaitAndRecordEventMany) { } } +TEST(GradientCheckerTest, ExpandGrad) { + float max_error; + GradientChecker gradient_checker; + OpDef op_def{"Expand"}; + + //input_shape = (2, 3, 1), target_shape = (2, 3, 4) ==> shape(result) = (2, 3, 4) + { + TensorInfo x_info({2, 3, 1}, true); + TensorInfo shape_info({3}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{1, 2, 3, 4, 5, 6}, {2, 3, 4}}; + + TensorInfo y_info({2, 3, 4}, true); + + gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas); + EXPECT_IS_TINY(max_error); + } + + //input_shape = (2, 3, 1), target_shape = (1, 1, 4) ==> shape(result) = (2, 3, 4) + { + TensorInfo x_info({2, 3, 1}, true); + TensorInfo shape_info({3}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{1, 2, 3, 4, 5, 6}, {1, 1, 4}}; + + TensorInfo y_info({2, 3, 4}, true); + + gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas); + EXPECT_IS_TINY(max_error); + } + + //input_shape = (2, 3, 1), target_shape = (4) ==> shape(result) = (2, 3, 4) + { + TensorInfo x_info({2, 3, 1}, true); + TensorInfo shape_info({1}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{1, 2, 3, 4, 5, 6}, {4}}; + + TensorInfo y_info({2, 3, 4}, true); + + gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas); + EXPECT_IS_TINY(max_error); + } + + //input_shape = (2, 3, 1), target_shape = (1, 1) ==> shape(result) = (2, 3, 1) + { + TensorInfo x_info({2, 3, 1}, true); + TensorInfo shape_info({2}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{1, 2, 3, 4, 5, 6}, {1, 1}}; + + TensorInfo y_info({2, 3, 1}, true); + + gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas); + EXPECT_IS_TINY(max_error); + } + + //input_shape = (2, 3), target_shape = (4, 5, 2, 3) ==> shape(result) = (4, 5, 2, 3) + { + TensorInfo x_info({2, 3}, true); + TensorInfo shape_info({4}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{1, 2, 3, 4, 5, 6}, {4, 5, 2, 3}}; + + TensorInfo y_info({4, 5, 2, 3}, true); + + gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas); + EXPECT_IS_TINY(max_error); + } + + //input_shape = (1, 2, 3), target_shape = (4, 5, 1, 1) ==> shape(result) = (4, 5, 2, 3) + { + TensorInfo x_info({1, 2, 3}, true); + TensorInfo shape_info({4}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{1, 2, 3, 4, 5, 6}, {4, 5, 1, 1}}; + + TensorInfo y_info({4, 5, 2, 3}, true); + + gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas); + EXPECT_IS_TINY(max_error); + } +} + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc index 1c1a63bbe5..7d6b015219 100644 --- a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc +++ b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc @@ -997,6 +997,106 @@ class PipelineBatchPlanner { } }; +// verify pipeline config can load and gradient graph can construct. +TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) { + PathString filename_base = ORT_TSTR("testdata/test_training_model_"); + + auto load_gradient_graph = [](int stageIdx, PathString& input_file, PathString& output_file) { + auto config = MakeBasicTrainingConfig(); + + config.use_pipeline = true; + + PathString backprop_model_file; + ASSERT_STATUS_OK(BuildBackPropGraph(input_file, config, backprop_model_file)); + + std::shared_ptr model; + ASSERT_TRUE(Model::Load(backprop_model_file, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + + Graph& graph = model->MainGraph(); + auto is_backward = [](Node& node) { + return (node.Description() == "Backward pass"); + }; + // check for wait/record node + Node* wait_fw{nullptr}; + Node* wait_bw{nullptr}; + Node* record_fw{nullptr}; + Node* record_bw{nullptr}; + for (auto& node : graph.Nodes()) { + if (node.OpType() == "WaitEvent") { + if (is_backward(node)) { + wait_bw = &node; + } else { + wait_fw = &node; + } + } else if (node.OpType() == "RecordEvent") { + if (is_backward(node)) { + record_bw = &node; + } else { + record_fw = &node; + } + } + } + // every partition should have wait forward and record backward + ASSERT_TRUE(wait_fw && record_bw); + if (stageIdx == 2) { + // the last partition can perform back prop right away. It won't have record + // forward and wait backward + ASSERT_TRUE(!record_fw && !wait_bw); + } else { + ASSERT_TRUE(record_fw && wait_bw); + } + + // check for send/recv node + Node* send_fw{nullptr}; + Node* send_bw{nullptr}; + Node* recv_fw{nullptr}; + Node* recv_bw{nullptr}; + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Send") { + if (is_backward(node)) { + send_bw = &node; + } else { + send_fw = &node; + } + } else if (node.OpType() == "Recv") { + if (is_backward(node)) { + recv_bw = &node; + } else { + recv_fw = &node; + } + } + } + // except the last partion, each partition should have send forward and recv backward + if (stageIdx == 0 || stageIdx == 1) { + ASSERT_TRUE(send_fw && recv_bw); + } else { + ASSERT_TRUE(!send_fw && !recv_bw); + } + // except the first partion, each partition should have recv forward and send backward + if (stageIdx == 1 || stageIdx == 2) { + ASSERT_TRUE(recv_fw && send_bw); + } else { + ASSERT_TRUE(!recv_fw && !send_bw); + } + + auto mp = model->ToProto(); + std::ofstream ofs(output_file, std::ofstream::binary); + mp.SerializeToOstream(&ofs); + ofs.close(); + }; + + for (int i = 0; i < 3; ++i) { +#ifdef _WIN32 + auto surfix = std::to_wstring(i); +#else + auto surfix = std::to_string(i); +#endif + PathString input_file = filename_base + surfix + ORT_TSTR(".onnx"); + PathString output_file = filename_base + surfix + ORT_TSTR("_back.onnx"); + load_gradient_graph(i, input_file, output_file); + } +} + TEST(GradientGraphBuilderTest, TrainingSession_WithPipeline) { auto config = MakeBasicTrainingConfig(); //config.set_gradients_as_graph_outputs = true; diff --git a/orttraining/orttraining/test/training_ops/cuda/gather_elements_grad_test.cc b/orttraining/orttraining/test/training_ops/cuda/gather_elements_grad_test.cc new file mode 100644 index 0000000000..88a41ca884 --- /dev/null +++ b/orttraining/orttraining/test/training_ops/cuda/gather_elements_grad_test.cc @@ -0,0 +1,227 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace cuda { +namespace test { + +TEST(GatherElementsGrad, WithoutAxis) { + onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain); + test.AddInput("dY", {2, 3}, + {1.0f, 1.1f, 1.2f, + 2.0f, 2.1f, 2.2f}); + std::vector data_shape = {3, 3}; + test.AddInput("data_shape", {2}, data_shape); + test.AddInput("indices", {2, 3}, + {1, 0, 2, + 0, 2, 1}); + test.AddOutput("dX", {3, 3}, + {2.0f, 1.1f, 0.0f, + 1.0f, 0.0f, 2.2f, + 0.0f, 2.1f, 1.2f}); + test.Run(); +} + +TEST(GatherElementsGrad, WithAxis) { + onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain); + test.AddAttribute("axis", 1); + test.AddInput("dY", {1, 2}, {1.1f, 2.1f}); + std::vector data_shape = {1, 5}; + test.AddInput("data_shape", {2}, data_shape); + test.AddInput("indices", {1, 2}, {1, 3}); + test.AddOutput("dX", {1, 5}, {0.0f, 1.1f, 0.0f, 2.1f, 0.0f}); + test.Run(); +} + +TEST(GatherElementsGrad, ThreeDimsWithAxis_0) { + onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain); + test.AddAttribute("axis", 0); + + test.AddInput("dY", {1, 3, 3}, + {11.0f, 12.0f, 13.0f, + 14.0f, 15.0f, 16.0f, + 17.0f, 18.0f, 19.0f}); + + std::vector data_shape = {1, 3, 3}; + test.AddInput("data_shape", {3}, data_shape); + + // Because axis 0 is only 1 dimension it should be all zeros + test.AddInput("indices", {1, 3, 3}, + {0, 0, 0, + 0, 0, 0, + 0, 0, 0}); + + test.AddOutput("dX", {1, 3, 3}, + {11.0f, 12.0f, 13.0f, + 14.0f, 15.0f, 16.0f, + 17.0f, 18.0f, 19.0f}); + test.Run(); +} + +TEST(GatherElementsGrad, ThreeDimsWithAxis_2) { + onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain); + test.AddAttribute("axis", 2); + + test.AddInput("dY", {1, 3, 3}, + {11, 12, 13, + 14, 15, 16, + 17, 18, 19}); + + std::vector data_shape = {1, 3, 3}; + test.AddInput("data_shape", {3}, data_shape); + + test.AddInput("indices", {1, 3, 3}, + {2, 1, 0, + 2, 1, 0, + 2, 1, 0}); + + test.AddOutput("dX", {1, 3, 3}, + {13, 12, 11, + 16, 15, 14, + 19, 18, 17}); + test.Run(); +} + +TEST(GatherElementsGrad, NegativeAxis) { + onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain); + test.AddAttribute("axis", -1); + test.AddInput("dY", {1, 2}, {1.1f, 2.1f}); + std::vector data_shape = {1, 5}; + test.AddInput("data_shape", {2}, data_shape); + test.AddInput("indices", {1, 2}, {1, 3}); + test.AddOutput("dX", {1, 5}, {0.0f, 1.1f, 0.0f, 2.1f, 0.0f}); + test.Run(); +} + +TEST(GatherElementsGrad, IndicesUpdatesDontMatch) { + onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain); + test.AddAttribute("axis", 1); + test.AddInput("dY", {1, 2}, {1.1f, 2.1f}); + std::vector data_shape = {1, 5}; + test.AddInput("data_shape", {2}, data_shape); + test.AddInput("indices", {1, 3}, {1, 3, 3}); + test.AddOutput("dX", {1, 5}, {1.0f, 3.1f, 3.0f, 6.1f, 5.0f}); + test.Run(onnxruntime::test::OpTester::ExpectResult::kExpectFailure, "Indices vs dY dimensions differs at position=1 3 vs 2"); +} + +TEST(GatherElementsGrad, ValidAxis) { + onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain); + test.AddAttribute("axis", 0); + test.AddInput("dY", {1, 1, 1}, {5.0f}); + std::vector data_shape = {4, 2, 1}; + test.AddInput("data_shape", {3}, data_shape); + test.AddInput("indices", {1, 1, 1}, {3}); + test.AddOutput("dX", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 5.0f, 0.0f}); + test.Run(); +} + +TEST(GatherElementsGrad, ValidNegativeIndex) { + onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain); + test.AddAttribute("axis", 0); + test.AddInput("dY", {1, 1, 1}, {5.0f}); + std::vector data_shape = {4, 2, 1}; + test.AddInput("data_shape", {3}, data_shape); + test.AddInput("indices", {1, 1, 1}, {-1}); + test.AddOutput("dX", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 5.0f, 0.0f}); + test.Run(); +} + +TEST(GatherElementsGrad, SameUpdateWithoutAxis) { + onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain); + test.AddInput("dY", {2, 2}, + {11.0f, 22.0f, + 33.0f, 44.0f}); + + std::vector data_shape = {3, 3}; + test.AddInput("data_shape", {2}, data_shape); + + test.AddInput("indices", {2, 2}, + {1, 1, + 1, 1}, + true); + + test.AddOutput("dX", {3, 3}, + {0.0f, 0.0f, 0.0f, + 44.0f, 66.0f, 0.0f, + 0.0f, 0.0f, 0.0f}); + test.Run(); +} + +TEST(GatherElementsGrad, SameUpdateWithAxis) { + onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain); + test.AddAttribute("axis", 1); + test.AddInput("dY", {2, 3}, + {11.0f, 22.0f, 33.0f, + 44.0f, 55.0f, 66.0f}); + + std::vector data_shape = {3, 3}; + test.AddInput("data_shape", {2}, data_shape); + + test.AddInput("indices", {2, 3}, + {1, 1, 1, + 1, 1, 1}, + true); + + test.AddOutput("dX", {3, 3}, + {0.0f, 66.0f, 0.0f, + 0.0f, 165.0f, 0.0f, + 0.0f, 0.0f, 0.0f}); + test.Run(); +} + +TEST(GatherElementsGrad, SameUpdateWithNegativeAxis) { + onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain); + test.AddAttribute("axis", -1); + test.AddInput("dY", {2, 3}, + {11.0f, 22.0f, 33.0f, + 44.0f, 55.0f, 66.0f}); + + std::vector data_shape = {3, 3}; + test.AddInput("data_shape", {2}, data_shape); + + test.AddInput("indices", {2, 3}, + {1, 0, 1, + 1, 0, 1}, + true); + + test.AddOutput("dX", {3, 3}, + {22.0f, 44.0f, 0.0f, + 55.0f, 110.0f, 0.0f, + 0.0f, 0.0f, 0.0f}); + test.Run(); +} + +TEST(GatherElementsGrad, SameUpdateWithoutAxisMLFloat16) { + onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain); + std::vector update = {11.0f, 22.0f, + 33.0f, 44.0f}; + std::vector fp16_update(update.size()); + onnxruntime::test::ConvertFloatToMLFloat16(update.data(), fp16_update.data(), static_cast(update.size())); + + std::vector output = {0.0f, 0.0f, 0.0f, + 44.0f, 66.0f, 0.0f, + 0.0f, 0.0f, 0.0f}; + std::vector fp16_output(output.size()); + onnxruntime::test::ConvertFloatToMLFloat16(output.data(), fp16_output.data(), static_cast(output.size())); + + test.AddInput("dY", {2, 2}, fp16_update); + + std::vector data_shape = {3, 3}; + test.AddInput("data_shape", {2}, data_shape); + + test.AddInput("indices", {2, 2}, + {1, 1, + 1, 1}, + true); + + test.AddOutput("dX", {3, 3}, fp16_output); + + test.Run(); +} + +} // namespace test +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/communication/recv.h b/orttraining/orttraining/training_ops/cuda/communication/recv.h index abfc8ca03a..0d1a812038 100644 --- a/orttraining/orttraining/training_ops/cuda/communication/recv.h +++ b/orttraining/orttraining/training_ops/cuda/communication/recv.h @@ -6,7 +6,6 @@ #pragma once #include "core/common/common.h" #include "core/providers/cuda/cuda_common.h" -#include "core/providers/cuda/cudnn_common.h" namespace onnxruntime { namespace cuda { diff --git a/orttraining/orttraining/training_ops/cuda/communication/send.h b/orttraining/orttraining/training_ops/cuda/communication/send.h index 3350c519ed..878fee48d7 100644 --- a/orttraining/orttraining/training_ops/cuda/communication/send.h +++ b/orttraining/orttraining/training_ops/cuda/communication/send.h @@ -6,7 +6,7 @@ #pragma once #include "core/common/common.h" #include "core/providers/cuda/cuda_common.h" -#include "core/providers/cuda/cudnn_common.h" + namespace onnxruntime { namespace cuda { diff --git a/orttraining/orttraining/training_ops/cuda/math/isfinite.h b/orttraining/orttraining/training_ops/cuda/math/isfinite.h index ea11c1d43b..c5073f81cf 100644 --- a/orttraining/orttraining/training_ops/cuda/math/isfinite.h +++ b/orttraining/orttraining/training_ops/cuda/math/isfinite.h @@ -4,7 +4,7 @@ #pragma once #include "core/common/common.h" #include "core/framework/op_kernel.h" -#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/multi_tensor/common.cuh" constexpr int PARALLEL_LOADS = 4; diff --git a/orttraining/orttraining/training_ops/cuda/math/mixed_precision_scale.h b/orttraining/orttraining/training_ops/cuda/math/mixed_precision_scale.h index 64e8d05224..7e92dc1b17 100644 --- a/orttraining/orttraining/training_ops/cuda/math/mixed_precision_scale.h +++ b/orttraining/orttraining/training_ops/cuda/math/mixed_precision_scale.h @@ -4,7 +4,7 @@ #pragma once #include "core/common/common.h" #include "core/framework/op_kernel.h" -#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/cuda_common.h" namespace onnxruntime { namespace cuda { diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm.h b/orttraining/orttraining/training_ops/cuda/nn/layer_norm.h index c1cc447ace..fd21b09ba4 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm.h +++ b/orttraining/orttraining/training_ops/cuda/nn/layer_norm.h @@ -1,7 +1,6 @@ #pragma once #include "core/common/common.h" #include "core/providers/cuda/cuda_common.h" -#include "core/providers/cuda/cudnn_common.h" namespace onnxruntime { namespace cuda { diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu index e8ba8a0e6c..7f85c5676f 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu @@ -190,8 +190,8 @@ __device__ void cuWelfordMuSigma2( for (; l + 7 < n2; l += 8 * numx) { for (int k = 0; k < 8; k += 2) { float2 curr = __half22float2(*((__half2*)(lvals + l + k))); - cuWelfordOnlineSum(curr.x, mu, sigma2, count); - cuWelfordOnlineSum(curr.y, mu, sigma2, count); + cuWelfordOnlineSum(static_cast(curr.x), mu, sigma2, count); + cuWelfordOnlineSum(static_cast(curr.y), mu, sigma2, count); } } for (; l < n2; ++l) { @@ -308,7 +308,7 @@ __global__ void cuApplyLayerNorm( // 1) blockDim.x == GPU_WARP_SIZE // 2) Tensors are contiguous // - for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { SharedMemory shared; U* buf = shared.getPointer(); U mu, sigma2; @@ -576,7 +576,7 @@ __global__ void cuComputeGradInput( const U* __restrict__ invvar, const T* gamma, T* grad_input) { - for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { U sum_loss1 = U(0); U sum_loss2 = U(0); const U c_mean = mean[i1]; diff --git a/orttraining/orttraining/training_ops/cuda/optimizer/adam.h b/orttraining/orttraining/training_ops/cuda/optimizer/adam.h index fcfa617dbc..a35625885b 100644 --- a/orttraining/orttraining/training_ops/cuda/optimizer/adam.h +++ b/orttraining/orttraining/training_ops/cuda/optimizer/adam.h @@ -4,7 +4,6 @@ #pragma once #include "core/common/common.h" #include "core/providers/cuda/cuda_common.h" -#include "core/providers/cuda/cudnn_common.h" namespace onnxruntime { namespace cuda { diff --git a/orttraining/orttraining/training_ops/cuda/optimizer/gradient_control.h b/orttraining/orttraining/training_ops/cuda/optimizer/gradient_control.h index bd4d2fd3de..bf8c12d51a 100644 --- a/orttraining/orttraining/training_ops/cuda/optimizer/gradient_control.h +++ b/orttraining/orttraining/training_ops/cuda/optimizer/gradient_control.h @@ -4,7 +4,6 @@ #pragma once #include "core/common/common.h" #include "core/providers/cuda/cuda_common.h" -#include "core/providers/cuda/cudnn_common.h" namespace onnxruntime { namespace cuda { diff --git a/orttraining/orttraining/training_ops/cuda/optimizer/lamb.h b/orttraining/orttraining/training_ops/cuda/optimizer/lamb.h index bbb34ba208..55a828a94f 100644 --- a/orttraining/orttraining/training_ops/cuda/optimizer/lamb.h +++ b/orttraining/orttraining/training_ops/cuda/optimizer/lamb.h @@ -4,7 +4,6 @@ #pragma once #include "core/common/common.h" #include "core/providers/cuda/cuda_common.h" -#include "core/providers/cuda/cudnn_common.h" #include "core/providers/cuda/multi_tensor/common.cuh" namespace onnxruntime { diff --git a/orttraining/orttraining/training_ops/cuda/optimizer/sg.h b/orttraining/orttraining/training_ops/cuda/optimizer/sg.h index 91ffbbbab9..a58d98e1f9 100644 --- a/orttraining/orttraining/training_ops/cuda/optimizer/sg.h +++ b/orttraining/orttraining/training_ops/cuda/optimizer/sg.h @@ -4,7 +4,6 @@ #pragma once #include "core/common/common.h" #include "core/providers/cuda/cuda_common.h" -#include "core/providers/cuda/cudnn_common.h" namespace onnxruntime { namespace cuda { diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_elements_grad.cc b/orttraining/orttraining/training_ops/cuda/tensor/gather_elements_grad.cc new file mode 100644 index 0000000000..4c49aad6d2 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/gather_elements_grad.cc @@ -0,0 +1,137 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/training_ops/cuda/tensor/gather_elements_grad.h" +#include "orttraining/training_ops/cuda/tensor/gather_elements_grad_impl.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/common.h" + +namespace onnxruntime { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + GatherElementsGrad, + kMSDomain, + 1, + kCudaExecutionProvider, + KernelDefBuilder() + .InputMemoryType(1) // 'GatherElements' data shape needs to be on CPU + .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()) + .TypeConstraint("I", DataTypeImpl::GetTensorType()) + .TypeConstraint("Tind", std::vector{DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + GatherElementsGrad); + +template +struct GatherElementsGrad::ComputeImpl { + Status operator()(const Tensor* dY, + const Tensor* indices_tensor, + Tensor* dX, + const int rank, + TArray& buffer_output_dims, + TArray& buffer_input_strides, + const int64_t indices_size, + TArray& buffer_indices_dims, + TArray& fdm_indices_strides, + const int axis) const { + T* output_data = dX->template MutableData(); + const T* update_data = dY->template Data(); + typedef typename ToCudaType::MappedType CudaT; + + MLDataType Tin_type = indices_tensor->DataType(); + if (utils::IsPrimitiveDataType(Tin_type)) { + const int32_t* indices_data = indices_tensor->template Data(); + return GatherElementsGradImpl( + rank, + buffer_output_dims, + buffer_input_strides, + indices_data, + indices_size, + buffer_indices_dims, + fdm_indices_strides, + reinterpret_cast(update_data), + axis, + reinterpret_cast(output_data)); + } else if (utils::IsPrimitiveDataType(Tin_type)) { + const int64_t* indices_data = indices_tensor->template Data(); + return GatherElementsGradImpl( + rank, + buffer_output_dims, + buffer_input_strides, + indices_data, + indices_size, + buffer_indices_dims, + fdm_indices_strides, + reinterpret_cast(update_data), + axis, + reinterpret_cast(output_data)); + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for Tin is not supported yet in ScatterElements."); + } +}; + +Status GatherElementsGrad::ComputeInternal(OpKernelContext* context) const { + const auto* dY = context->Input(0); + const Tensor* shape = context->Input(1); + const TensorShape data_shape(shape->template Data(), shape->Shape().Size()); + + const int axis = static_cast(HandleNegativeAxis(axis_, data_shape.NumDimensions())); + + const auto* indices_tensor = context->Input(2); + + const auto& indices_dims = indices_tensor->Shape().GetDims(); + const int64_t indices_size = indices_tensor->Shape().Size(); + const auto& dY_dims = dY->Shape().GetDims(); + if (indices_dims.size() != dY_dims.size()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Indices and dY must have the same rank"); + } + + for (size_t i = 0; i < indices_dims.size(); ++i) { + if (indices_dims[i] != dY_dims[i]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices vs dY dimensions differs at position=", i, + " ", indices_dims[i], " vs ", dY_dims[i]); + } + } + + // According to the spec the rank of ind/upd shall be the same as output(data) + // and we also want to make sure that the dimensions of the of the ind/upd do not + // exceed that of the output + const auto& output_dims = data_shape.GetDims(); + if (output_dims.size() != indices_dims.size()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices must have the same rank as Output. Indices rank=", + indices_dims.size(), ". Output rank=", output_dims.size()); + } + + for (size_t i = 0; i < output_dims.size(); ++i) { + if (output_dims[i] < indices_dims[i]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices dim=", indices_dims[i], " at pos=", i, + " is greater than Output dim=", output_dims[i]); + } + } + + int rank = static_cast(output_dims.size()); + Tensor* dX = context->Output(0, data_shape); + CUDA_RETURN_IF_ERROR(cudaMemset(dX->MutableDataRaw(), 0, dX->SizeInBytes())); + + TArray buffer_output_dims(output_dims); + TensorPitches input_strides(output_dims); + TArray buffer_input_strides(input_strides); + + TArray buffer_indices_dims(indices_dims); + TArray fdm_indices_strides(rank); + TensorPitches indices_strides(indices_dims); + for (auto i = 0; i < rank; i++) { + fdm_indices_strides[i] = fast_divmod(static_cast(indices_strides[i])); + } + + utils::MLTypeCallDispatcherRet + t_disp(dY->GetElementType()); + return t_disp.Invoke(dY, indices_tensor, dX, rank, + buffer_output_dims, buffer_input_strides, indices_size, + buffer_indices_dims, fdm_indices_strides, axis); +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_elements_grad.h b/orttraining/orttraining/training_ops/cuda/tensor/gather_elements_grad.h new file mode 100644 index 0000000000..82291b3968 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/gather_elements_grad.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cuda/cuda_common.h" + +namespace onnxruntime { +namespace cuda { + +class GatherElementsGrad final : public CudaKernel { + public: + GatherElementsGrad(const OpKernelInfo& info) : CudaKernel(info) { + info.GetAttrOrDefault("axis", &axis_, static_cast(0)); + } + ~GatherElementsGrad() = default; + Status ComputeInternal(OpKernelContext* context) const override; + + private: + template + struct ComputeImpl; + + int64_t axis_; +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_elements_grad_impl.h b/orttraining/orttraining/training_ops/cuda/tensor/gather_elements_grad_impl.h new file mode 100755 index 0000000000..713fe3f7bc --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/gather_elements_grad_impl.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace cuda { + +template +Status GatherElementsGradImpl( + const int rank, + TArray& buffer_input_dims, + TArray& buffer_input_strides, + const Tin* indices_data, + const int64_t indices_size, + TArray& buffer_indices_dims, + TArray& indices_strides, + const T* updates, + const int axis, + T* output_data); + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda_training_kernels.cc index 16751ea2e0..76c04462ae 100644 --- a/orttraining/orttraining/training_ops/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda_training_kernels.cc @@ -121,6 +121,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_float, LayerNormalizationGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float, LayerNormalizationGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SliceGrad); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherElementsGrad); #ifdef USE_HOROVOD class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, HorovodAllReduce); @@ -252,6 +253,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifdef USE_HOROVOD BuildKernelCreateInfo, diff --git a/orttraining/pytorch_frontend_examples/mnist_training.py b/orttraining/pytorch_frontend_examples/mnist_training.py index e7f11e2725..ed73a132e4 100644 --- a/orttraining/pytorch_frontend_examples/mnist_training.py +++ b/orttraining/pytorch_frontend_examples/mnist_training.py @@ -17,6 +17,7 @@ import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms import numpy as np +import os # TODO: remove after ready for CV # import sys @@ -27,13 +28,15 @@ import numpy as np # from ort_trainer import IODescription, ModelDescription, ORTTrainer, ORTModel from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer, ORTModel +from mpi4py import MPI +from onnxruntime.capi._pybind_state import set_cuda_device_id class NeuralNet(nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNet, self).__init__() - self.fc1 = nn.Linear(input_size, hidden_size) + self.fc1 = nn.Linear(input_size, hidden_size) self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) + self.fc2 = nn.Linear(hidden_size, num_classes) def forward(self, x): out = self.fc1(x) @@ -79,18 +82,19 @@ def test_with_model(args, model, device, test_loader, optimizer, epoch): test_loss, correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset))) -def train_with_trainer(args, trainer, device, train_loader, epoch): +def train_with_trainer(args, trainer, device, train_loader, epoch): for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) data = data.reshape(data.shape[0], -1) learning_rate = torch.tensor([args.lr]) - loss = trainer.train_step((data, target, learning_rate)) + loss = trainer.train_step(data, target, learning_rate) + # Since the output corresponds to [loss_desc, probability_desc], the first value is taken as loss. if batch_idx % args.log_interval == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( epoch, batch_idx * len(data), len(train_loader.dataset), - 100. * batch_idx / len(train_loader), loss.item())) + 100. * batch_idx / len(train_loader), loss[0])) # TODO: comple this once ORT training can do evaluation. def test_with_trainer(args, trainer, device, test_loader): @@ -152,8 +156,6 @@ def main(): torch.manual_seed(args.seed) - device = torch.device("cuda" if use_cuda else "cpu") - kwargs = {'num_workers': 0, 'pin_memory': True} train_loader = torch.utils.data.DataLoader( datasets.MNIST('../data', train=True, download=True, @@ -167,6 +169,19 @@ def main(): transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])), batch_size=args.test_batch_size, shuffle=True, **kwargs) + + comm = MPI.COMM_WORLD + args.local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) if ('OMPI_COMM_WORLD_LOCAL_RANK' in os.environ) else 0 + args.world_rank = int(os.environ['OMPI_COMM_WORLD_RANK']) if ('OMPI_COMM_WORLD_RANK' in os.environ) else 0 + args.world_size=comm.Get_size() + torch.cuda.set_device(args.local_rank) + if use_cuda: + device = torch.device("cuda", args.local_rank) + else: + device = torch.device("cpu") + args.n_gpu = 1 + set_cuda_device_id(args.local_rank) + input_size = 784 hidden_size = 500 num_classes = 10 @@ -175,14 +190,17 @@ def main(): model_desc = mnist_model_description() if args.use_ort_trainer: # use log_interval as gradient accumulate steps - trainer = ORTTrainer(model, my_loss, model_desc, "SGDOptimizer", None, IODescription('Learning_Rate', [1,], torch.float32), device) + trainer = ORTTrainer(model, my_loss, model_desc, "LambOptimizer", None, IODescription('Learning_Rate', [1,], torch.float32), device, 1, None, + args.world_rank, args.world_size, use_mixed_precision=False, allreduce_post_accumulation = True) + print('\nBuild ort model done.') for epoch in range(1, args.epochs + 1): train_with_trainer(args, trainer, device, train_loader, epoch) import pdb test_with_trainer(args, trainer, device, test_loader) else: - model = ORTModel(model, my_loss, model_desc, device) + model = ORTModel(model, my_loss, model_desc, device, None, args.world_rank, args.world_size) + print('\nBuild ort model done.') optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) diff --git a/orttraining/tools/scripts/pipeline_model_split.py b/orttraining/tools/scripts/pipeline_model_split.py index d0385edeac..008e626e32 100644 --- a/orttraining/tools/scripts/pipeline_model_split.py +++ b/orttraining/tools/scripts/pipeline_model_split.py @@ -4,7 +4,6 @@ import onnx from onnx import helper from onnx import TensorProto from onnx import OperatorSetIdProto - # Edge that needs to be cut for the split. # If the edge is feeding into more than one nodes, and not all the nodes belong to the same cut, # specify those consuming nodes that need to be cut @@ -29,31 +28,6 @@ def split_graph(model, split_edge_groups): new_send_nodes = [] new_recv_nodes = [] - # Add wait for initial inputs. This needs to be done first before new inputs - # are introduced from split - initializer_lists = [a.name for a in model.graph.initializer] - input_tensors = [ - value.name for value in model.graph.input if value.name not in initializer_lists] - - input_wait_signal = model.graph.input.add() - input_wait_signal.CopyFrom(helper.make_tensor_value_info( - 'input_wait_signal', onnx.TensorProto.INT64, None)) - - input_wait = model.graph.node.add() - input_wait.CopyFrom(helper.make_node( - 'WaitEvent', - inputs=['input_wait_signal'], - outputs=[], - domain=ms_domain)) - - for i in input_tensors: - for node in model.graph.node: - for j in range(len(node.input)): - if node.input[j] == i: - node.input[j] = i + '_sync' - - input_wait.input.extend(input_tensors) - input_wait.output.extend([i + '_sync' for i in input_tensors]) for cut_index in range(len(split_edge_groups)): edgeIds = split_edge_groups[cut_index] @@ -62,7 +36,7 @@ def split_graph(model, split_edge_groups): upstream_nodes = [] upstream_nodes_output_index = [] output_shapes = [] - + element_types = [] for id in edgeIds: for node in model.graph.node: if len(node.output) >= 1: @@ -70,25 +44,43 @@ def split_graph(model, split_edge_groups): if j == id: upstream_nodes.append(node) upstream_nodes_output_index.append(i) - for info in model.graph.value_info: - if info.name == id: - output_shapes.append(info.type) - - record_signal = model.graph.input.add() - record_signal.CopyFrom(helper.make_tensor_value_info( - 'record_input_signal' + str(cut_index), onnx.TensorProto.INT64, None)) - - wait_signal = model.graph.input.add() - wait_signal.CopyFrom(helper.make_tensor_value_info( - 'wait_input_signal' + str(cut_index), onnx.TensorProto.INT64, None)) + # assuming all tensors are of type float + element_types.append(1) + for info in model.graph.value_info: + if info.name == id: + output_shapes.append(info.type) + send_input_signal_name = 'send_input_signal' + str(cut_index) send_signal = model.graph.input.add() send_signal.CopyFrom(helper.make_tensor_value_info( - 'send_input_signal' + str(cut_index), onnx.TensorProto.BOOL, None)) + send_input_signal_name, onnx.TensorProto.BOOL, None)) + send_signal = helper.make_tensor( + send_input_signal_name, TensorProto.BOOL, (), (True,)) + model.graph.initializer.extend([send_signal]) + recv_input_signal_name = 'recv_input_signal' + str(cut_index) recv_signal = model.graph.input.add() recv_signal.CopyFrom(helper.make_tensor_value_info( - 'recv_input_signal' + str(cut_index), onnx.TensorProto.BOOL, None)) + recv_input_signal_name, onnx.TensorProto.BOOL, None)) + recv_signal = helper.make_tensor( + recv_input_signal_name, TensorProto.BOOL, (), (True,)) + model.graph.initializer.extend([recv_signal]) + + send_dst_rank_name = 'send_dst_rank' + str(cut_index) + send_dst_rank = model.graph.input.add() + send_dst_rank.CopyFrom(helper.make_tensor_value_info( + send_dst_rank_name, onnx.TensorProto.INT64, None)) + send_dst_rank = helper.make_tensor( + send_dst_rank_name, TensorProto.INT64, (), (cut_index + 1,)) + model.graph.initializer.extend([send_dst_rank]) + + recv_src_rank_name = 'recv_src_rank' + str(cut_index) + recv_src_rank = model.graph.input.add() + recv_src_rank.CopyFrom(helper.make_tensor_value_info( + recv_src_rank_name, onnx.TensorProto.INT64, None)) + recv_src_rank = helper.make_tensor( + recv_src_rank_name, TensorProto.INT64, (), (cut_index,)) + model.graph.initializer.extend([recv_src_rank]) # output signal from send after cut send_output_signal = model.graph.output.add() @@ -103,41 +95,23 @@ def split_graph(model, split_edge_groups): new_send = model.graph.node.add() new_send.CopyFrom(helper.make_node( 'Send', - inputs=['send_input_signal' + str(cut_index)], + inputs=[send_input_signal_name, send_dst_rank_name], outputs=['send_output_signal' + str(cut_index)], tag=0, - src=cut_index, - dst=cut_index + 1, domain=ms_domain, - element_type=7, # assuming all tensors are of type float + element_types=element_types, name='send')) new_receive = model.graph.node.add() new_receive.CopyFrom(helper.make_node( 'Recv', - inputs=['recv_input_signal' + str(cut_index)], + inputs=[recv_input_signal_name, recv_src_rank_name], outputs=['receive_output_signal' + str(cut_index)], - tag=1, - src=cut_index, - dst=cut_index + 1, + tag=0, domain=ms_domain, - element_type=7, # assuming all tensors are of type float + element_types=element_types, name='receive')) - new_wait = model.graph.node.add() - new_wait.CopyFrom(helper.make_node( - 'WaitEvent', - inputs=['wait_input_signal' + str(cut_index)], - outputs=[], - domain=ms_domain)) - - new_record = model.graph.node.add() - new_record.CopyFrom(helper.make_node( - 'RecordEvent', - inputs=['record_input_signal' + str(cut_index)], - outputs=[], - domain=ms_domain)) - for i in range(len(upstream_nodes)): n = upstream_nodes[i] idx = upstream_nodes_output_index[i] @@ -155,24 +129,16 @@ def split_graph(model, split_edge_groups): '_recv' + str(cut_index) add_expand_type(model, new_receive_output_name, output_type) - new_wait_output_name = output_edge_name + '_wait' + str(cut_index) - add_expand_type(model, new_wait_output_name, output_type) - # the order of data flow is: node-output -> record -> send -> recv -> wait -> node-input - new_record.input.extend([output_edge_name]) - new_record.output.extend([new_send_input_name]) - new_send.input.extend([new_send_input_name]) + new_send.input.extend([output_edge_name]) new_receive.output.extend([new_receive_output_name]) - new_wait.input.extend([new_receive_output_name]) - new_wait.output.extend([new_wait_output_name]) - for output_node in output_nodes: for i in range(len(output_node.input)): for edgeId in edgeIds: if output_node.input[i] == edgeId: - output_node.input[i] = new_wait_output_name + output_node.input[i] = new_receive_output_name new_send_nodes.append(new_send) new_recv_nodes.append(new_receive) @@ -236,9 +202,50 @@ def add_identity(model, cuttingEdge, newEdgeIdName): if output_nodes[i].input[j] == edgeId: output_nodes[i].input[j] = newEdgeIdName - return newEdgeIdName + return new_identity +def insert_identity(model, all_cut_inputs): + count = 0 + updated_edges = {} + new_added_identity = [] + split_edge_groups = [] + need_shape_inference = False + # Sweep the cut edge to see if there are edges feeding into nodes from two sub-graphs. If so, + # insert identity node after those edges with a new ID to distinguish the rest. + for cut_input in all_cut_inputs: + split_edges = [] + for i in cut_input: + if i.consumingNodes: + # if this edge has previously been modified, update its edgeId before inserting new identity + if i.edgeId in updated_edges: + i.edgeId = updated_edges[i.edgeId] + + new_edge_name = 'identity_output_' + str(count) + new_added_identity.append( + add_identity(model, i, new_edge_name)) + count += 1 + split_edges.append(new_edge_name) + updated_edges[i.edgeId] = new_edge_name + need_shape_inference = True + else: + split_edges.append(i.edgeId) + split_edge_groups.append(split_edges) + return split_edge_groups, new_added_identity, need_shape_inference + +# after the graph is split, remove the added identity node because identity op is not registered in gradient builder. + + +def remove_identity(model, new_added_identity): + for node in new_added_identity: + assert node.op_type == 'Identity' + output_nodes = [ + n for n in model.graph.node if node.output[0] in n.input] + for output_node in output_nodes: + for i in range(len(output_node.input)): + if output_node.input[i] == node.output[0]: + output_node.input[i] = node.input[0] + def find_all_connected_nodes(model, node): nodes0, inputs = find_all_input_nodes(model, node) nodes1, outputs = find_all_output_nodes(model, node) @@ -251,15 +258,36 @@ def get_index(node_list, node): found = [i for i, n in enumerate(node_list) if n == node] return found[0] if found else None +def get_identity_index_for_deleting(node_list, node): + for i, n in enumerate(node_list): + # The node's input name has been changed during send/recv insertion, + # but it is sufficient to just compare the type and outputs. + if (n.op_type == 'Identity' and n.output == node.output): + return i + return None + # traverse the graph, group connected nodes and generate subgraph -def generate_subgraph(model, start_nodes): +def generate_subgraph(model, start_nodes, identity_node_list): subgraphs = [] main_graph = onnx.ModelProto() main_graph.CopyFrom(model) + # remove added identity node before copy to subgraph + identity_node_index = [] + for n in identity_node_list: + identity_node_index.append(get_identity_index_for_deleting(main_graph.graph.node, n)) + identity_node_index.sort(reverse=True) + + for i in reversed(range(len(main_graph.graph.node))): + try: + if i in identity_node_index: + del main_graph.graph.node[i] + except: + print("error deleting identity node", i) + all_visited_nodes = [] model_count = len(start_nodes) for start in reversed(start_nodes): @@ -362,29 +390,8 @@ def main(): output_model_names = [os.path.splitext(input_model_name)[0] + '_' + str(i) + '.onnx' for i in range(stage_count)] - split_edge_groups = [] - count = 0 - updated_edges = {} - need_shape_inference = False - # Sweep the cut edge to see if there are edges feeding into nodes from two sub-graphs. If so, - # insert identity node after those edges with a new ID to distinguish the rest. - for cut_input in all_cut_inputs: - split_edges = [] - for i in cut_input: - if i.consumingNodes: - # if this edge has previously been modified, update its edgeId before inserting new identity - if i.edgeId in updated_edges: - i.edgeId = updated_edges[i.edgeId] + split_edge_groups, new_identity, need_shape_inference = insert_identity(model, all_cut_inputs) - new_edge_name = 'identity_output_' + str(count) - add_identity(model, i, new_edge_name) - count += 1 - split_edges.append(new_edge_name) - updated_edges[i.edgeId] = new_edge_name - need_shape_inference = True - else: - split_edges.append(i.edgeId) - split_edge_groups.append(split_edges) # new edge is being added, need to re-inference shape if need_shape_inference: @@ -392,11 +399,13 @@ def main(): # after all need-to-be-cut edges identified, split the graph new_sends, new_receives = split_graph(model, split_edge_groups) - sub_graphs = generate_subgraph(model, new_receives) + remove_identity(model, new_identity) + sub_graphs = generate_subgraph(model, new_receives, new_identity) for i in range(stage_count): sub_graphs[i] = onnx.shape_inference.infer_shapes(sub_graphs[i]) onnx.save(sub_graphs[i], output_model_names[i]) + print("save to file: ", output_model_names[i]) if __name__ == "__main__":