mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-20 02:07:56 +00:00
Merge remote-tracking branch 'origin/ort_training' into edgchen1/merge_from_ort_training
This commit is contained in:
commit
4416d41874
58 changed files with 1614 additions and 224 deletions
|
|
@ -1110,6 +1110,14 @@ class Graph {
|
|||
// Graph value_info.
|
||||
std::vector<const NodeArg*> value_info_;
|
||||
|
||||
// Strings which have been used as node names.
|
||||
// New node name should not conflict with this set.
|
||||
std::unordered_set<std::string> generated_node_names_;
|
||||
|
||||
// Strings which have been used as node_arg names.
|
||||
// New node_arg name should not conflict with this set.
|
||||
std::unordered_set<std::string> generated_node_arg_names_;
|
||||
|
||||
// All node args owned by <*this> graph. Key is node arg name.
|
||||
std::unordered_map<std::string, std::unique_ptr<NodeArg>> node_args_;
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
|
||||
#include "attention.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/shared_inc/fpgeneric.h"
|
||||
#include "attention_impl.h"
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "contrib_ops/cpu/bert/attention.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "onnx/defs/tensor_proto_util.h"
|
||||
#include "contrib_ops/cpu/bert/embed_layer_norm_helper.h"
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "fast_gelu.h"
|
||||
#include "fast_gelu_impl.h"
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "onnx/defs/tensor_proto_util.h"
|
||||
#include "skip_layer_norm.h"
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
#include "layer_norm_impl.h"
|
||||
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
|
|||
|
|
@ -107,13 +107,14 @@ __device__ void cuWelfordMuSigma2(
|
|||
cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
|
||||
}
|
||||
// intra-warp reductions
|
||||
for (int l = 0; l <= 4; ++l) {
|
||||
int srcLaneB = (threadIdx.x + (1 << l)) & 31;
|
||||
U muB = WARP_SHFL(mu, srcLaneB);
|
||||
U countB = WARP_SHFL(count, srcLaneB);
|
||||
U sigma2B = WARP_SHFL(sigma2, srcLaneB);
|
||||
#pragma unroll
|
||||
for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) {
|
||||
U muB = WARP_SHFL_DOWN(mu, stride);
|
||||
U countB = WARP_SHFL_DOWN(count, stride);
|
||||
U sigma2B = WARP_SHFL_DOWN(sigma2, stride);
|
||||
cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
|
||||
}
|
||||
|
||||
// threadIdx.x == 0 has correct values for each warp
|
||||
// inter-warp reductions
|
||||
if (blockDim.y > 1) {
|
||||
|
|
@ -192,8 +193,8 @@ __device__ void cuWelfordMuSigma2(
|
|||
for (; l + 7 < n2; l += 8 * numx) {
|
||||
for (int k = 0; k < 8; k += 2) {
|
||||
float2 curr = __half22float2(*((__half2*)(lvals + l + k)));
|
||||
cuWelfordOnlineSum(curr.x, mu, sigma2, count);
|
||||
cuWelfordOnlineSum(curr.y, mu, sigma2, count);
|
||||
cuWelfordOnlineSum(static_cast<float>(curr.x), mu, sigma2, count);
|
||||
cuWelfordOnlineSum(static_cast<float>(curr.y), mu, sigma2, count);
|
||||
}
|
||||
}
|
||||
for (; l < n2; ++l) {
|
||||
|
|
@ -201,13 +202,14 @@ __device__ void cuWelfordMuSigma2(
|
|||
cuWelfordOnlineSum(curr, mu, sigma2, count);
|
||||
}
|
||||
// intra-warp reductions
|
||||
for (int l = 0; l <= 4; ++l) {
|
||||
int srcLaneB = (threadIdx.x + (1 << l)) & 31;
|
||||
float muB = WARP_SHFL(mu, srcLaneB);
|
||||
float countB = WARP_SHFL(count, srcLaneB);
|
||||
float sigma2B = WARP_SHFL(sigma2, srcLaneB);
|
||||
#pragma unroll
|
||||
for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) {
|
||||
float muB = WARP_SHFL_DOWN(mu, stride);
|
||||
float countB = WARP_SHFL_DOWN(count, stride);
|
||||
float sigma2B = WARP_SHFL_DOWN(sigma2, stride);
|
||||
cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
|
||||
}
|
||||
|
||||
// threadIdx.x == 0 has correct values for each warp
|
||||
// inter-warp reductions
|
||||
if (blockDim.y > 1) {
|
||||
|
|
@ -310,7 +312,7 @@ __global__ void cuApplyLayerNorm(
|
|||
// 1) blockDim.x == GPU_WARP_SIZE
|
||||
// 2) Tensors are contiguous
|
||||
//
|
||||
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
|
||||
for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
|
||||
SharedMemory<U> shared;
|
||||
U* buf = shared.getPointer();
|
||||
U mu, sigma2;
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
|
|||
|
|
@ -2376,16 +2376,47 @@ Node& Graph::AddNode(const NodeProto& node_proto,
|
|||
}
|
||||
|
||||
std::string Graph::GenerateNodeArgName(const std::string& base_name) {
|
||||
// Check if base_name has already been used as one of node_args_' names.
|
||||
bool found = node_args_.find(base_name) != node_args_.end();
|
||||
// Check if base_name has been generated by this function.
|
||||
// If not, add base_name into name set and return the base_name
|
||||
// as the generated name.
|
||||
found |= generated_node_arg_names_.find(base_name) != generated_node_arg_names_.end();
|
||||
if (!found) {
|
||||
generated_node_arg_names_.insert(base_name);
|
||||
return base_name;
|
||||
}
|
||||
|
||||
// base_name has been used by another node. Because two node_arg's cannot have
|
||||
// the same name, we are going to generate another string.
|
||||
std::string new_name;
|
||||
do {
|
||||
std::ostringstream str;
|
||||
str << base_name << "_" << name_generator_++;
|
||||
new_name = str.str();
|
||||
} while (node_args_.find(new_name) != node_args_.end());
|
||||
// If node_args_ or generated_node_arg_names_ contains new_name, we go to the next iteration.
|
||||
} while (node_args_.find(new_name) != node_args_.end() ||
|
||||
generated_node_arg_names_.find(new_name) != generated_node_arg_names_.end());
|
||||
|
||||
// Now new_name is different than any of existing node_arg names.
|
||||
// Register new_name so that it won't be used again.
|
||||
generated_node_arg_names_.insert(new_name);
|
||||
|
||||
return new_name;
|
||||
}
|
||||
|
||||
std::string Graph::GenerateNodeName(const std::string& base_name) {
|
||||
// Check if base_name has already been used as one of nodes_' names.
|
||||
bool found = std::find_if(nodes_.cbegin(), nodes_.cend(), [&base_name](const std::unique_ptr<Node>& n) {
|
||||
return (n != nullptr) && (n->Name() == base_name);}) != nodes_.end();
|
||||
// Check if base_name has been generated by this function.
|
||||
found |= generated_node_names_.find(base_name) != generated_node_names_.end();
|
||||
if (!found) {
|
||||
// Register base_name so that it won't be used again.
|
||||
generated_node_names_.insert(base_name);
|
||||
return base_name;
|
||||
}
|
||||
|
||||
std::string new_name;
|
||||
bool keep_going = true;
|
||||
|
||||
|
|
@ -2394,11 +2425,18 @@ std::string Graph::GenerateNodeName(const std::string& base_name) {
|
|||
str << base_name << "_" << name_generator_++;
|
||||
new_name = str.str();
|
||||
|
||||
// Check if new_name has already been used as one of nodes_' names.
|
||||
keep_going = std::find_if(nodes_.cbegin(), nodes_.cend(), [&new_name](const std::unique_ptr<Node>& n) {
|
||||
return (n != nullptr) && (n->Name() == new_name);
|
||||
}) != nodes_.end();
|
||||
// Check if new_name has been generated by this function.
|
||||
keep_going |= generated_node_names_.find(new_name) != generated_node_names_.end();
|
||||
} while (keep_going);
|
||||
|
||||
// Now new_name is different than any of existing node names.
|
||||
// Register new_name so that it won't be used again.
|
||||
generated_node_names_.insert(new_name);
|
||||
|
||||
return new_name;
|
||||
}
|
||||
|
||||
|
|
|
|||
74
onnxruntime/core/optimizer/expand_elimination.cc
Normal file
74
onnxruntime/core/optimizer/expand_elimination.cc
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
#include "core/optimizer/expand_elimination.h"
|
||||
#include "core/graph/graph.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Removes the matched node from the graph. When graph_utils::RemoveNode reports
// success, rule_effect is set to kRemovedCurrentNode; otherwise it is left as the
// caller passed it. The returned status is always OK.
Status ExpandElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
  const bool node_removed = graph_utils::RemoveNode(graph, node);
  if (node_removed) {
    rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
  }

  return Status::OK();
}
|
||||
|
||||
// Returns true when the Expand node can be dropped because its output is
// guaranteed to equal its first input.
//
// The rule fires only if all of the following hold:
//   * graph_utils::CanRemoveNode accepts the node;
//   * the first input has a known shape;
//   * the second input (the target shape) is a constant, 1-D, non-empty int64 initializer;
//   * the target shape has no more dimensions than the input shape;
//   * walking both shapes from the trailing dimension, every target entry greater
//     than 1 matches a known input dimension of the same value (target entries
//     that are not greater than 1 are accepted for any input dimension).
bool ExpandElimination::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
  if (!graph_utils::CanRemoveNode(graph, node, logger)) {
    return false;
  }

  // 1. The input shape must be available.
  const auto* input_shape = node.InputDefs()[0]->Shape();
  if (input_shape == nullptr) {
    return false;
  }

  // 2. The target shape must be a constant 1-D initializer with at least one entry.
  const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name());
  if (tensor_proto == nullptr || tensor_proto->dims_size() != 1 || tensor_proto->dims(0) <= 0) {
    return false;
  }

  // Check the declared element type on the proto itself, so a tensor of the wrong
  // type is rejected before any effort is spent unpacking it.
  if (tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT64) {
    return false;
  }

  // The unpacked initializer is only needed within this function, so keep it on
  // the stack instead of heap-allocating it.
  Initializer initializer(*tensor_proto, graph.ModelPath());
  const int64_t* target_shapes = initializer.data<int64_t>();

  // Compare the dimensions starting at the trailing dimension.
  int i = input_shape->dim_size() - 1;
  int j = static_cast<int>(tensor_proto->dims(0) - 1);

  // Expand can only reproduce its input when the target shape does not have
  // more dimensions than the input shape.
  if (i < j) {
    return false;
  }

  while (i >= 0 && j >= 0) {
    // Bind by reference: copying the protobuf dimension message per iteration is unnecessary.
    const auto& dim = input_shape->dim(i);
    if (utils::HasDimValue(dim)) {
      const auto dim_value = dim.dim_value();
      if (dim_value != target_shapes[j] && target_shapes[j] > 1) {
        return false;
      }
    } else if (target_shapes[j] > 1) {
      // Symbolic/unknown input dimension: it cannot be proven equal to a target entry > 1.
      return false;
    }

    --i;
    --j;
  }

  return true;
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
31
onnxruntime/core/optimizer/expand_elimination.h
Normal file
31
onnxruntime/core/optimizer/expand_elimination.h
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
@Class ExpandElimination
|
||||
|
||||
Rewrite rule that eliminates Expand nodes if the node generates the same tensor as the input tensor.
|
||||
|
||||
It is attempted to be triggered only on nodes with op type "Expand".
|
||||
*/
|
||||
class ExpandElimination : public RewriteRule {
|
||||
public:
|
||||
ExpandElimination() noexcept : RewriteRule("ExpandElimination") {}
|
||||
|
||||
std::vector<std::string> TargetOpTypes() const noexcept override {
|
||||
return {"Expand"};
|
||||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -27,6 +27,7 @@
|
|||
#include "core/optimizer/embed_layer_norm_fusion.h"
|
||||
#include "core/optimizer/reshape_fusion.h"
|
||||
#include "core/optimizer/attention_fusion.h"
|
||||
#include "core/optimizer/expand_elimination.h"
|
||||
#include "core/optimizer/cast_elimination.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
|
||||
|
|
@ -47,6 +48,7 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(TransformerLevel
|
|||
rules.push_back(onnxruntime::make_unique<EliminateSlice>());
|
||||
rules.push_back(onnxruntime::make_unique<UnsqueezeElimination>());
|
||||
rules.push_back(onnxruntime::make_unique<EliminateDropout>());
|
||||
rules.push_back(onnxruntime::make_unique<ExpandElimination>());
|
||||
rules.push_back(onnxruntime::make_unique<CastElimination>());
|
||||
rules.push_back(onnxruntime::make_unique<FuseReluClip>());
|
||||
rules.push_back(onnxruntime::make_unique<ShapeToInitializer>());
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#include "gemm.h"
|
||||
#include "core/providers/cpu/math/gemm_helper.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/shared_inc/fpgeneric.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@
|
|||
// The code below is mostly copied from Pytorch PersistentSoftmax.cuh
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "gsl/gsl"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ namespace cuda {
|
|||
|
||||
class Concat final : public CudaKernel, public ConcatBase {
|
||||
public:
|
||||
Concat(const OpKernelInfo& info) : ConcatBase(info), CudaKernel(info) {}
|
||||
Concat(const OpKernelInfo& info) : CudaKernel(info), ConcatBase(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ namespace cuda {
|
|||
|
||||
class Gather final : public CudaKernel, public GatherBase {
|
||||
public:
|
||||
Gather(const OpKernelInfo& info) : GatherBase(info), CudaKernel(info) {}
|
||||
Gather(const OpKernelInfo& info) : CudaKernel(info), GatherBase(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,17 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/atomic/common.cuh"
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "scatter_elements_impl.h"
|
||||
#ifdef ENABLE_TRAINING
|
||||
#include "orttraining/training_ops/cuda/tensor/gather_elements_grad_impl.h"
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template <typename T, typename Tin, bool OUTERAXIS>
|
||||
template <typename T, typename Tin, bool OUTERAXIS, typename FuncT>
|
||||
__global__ void _ScatterElementsKernel2D(
|
||||
const int max_dim, // max dim on the scattered axis
|
||||
const T* input_data,
|
||||
|
|
@ -16,7 +20,8 @@ __global__ void _ScatterElementsKernel2D(
|
|||
const fast_divmod indices_stride_row,
|
||||
const T* updates,
|
||||
const int64_t output_row_size,
|
||||
T* output_data) {
|
||||
T* output_data,
|
||||
const FuncT& func) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(indices_index, indices_size);
|
||||
|
||||
int row, col, data_idx;
|
||||
|
|
@ -29,12 +34,13 @@ __global__ void _ScatterElementsKernel2D(
|
|||
} else {
|
||||
data_idx = row * output_row_size + dim;
|
||||
}
|
||||
output_data[data_idx] = updates[indices_index];
|
||||
|
||||
func(output_data + data_idx, updates + indices_index);
|
||||
}
|
||||
// else invalid index
|
||||
}
|
||||
|
||||
template <typename T, typename Tin>
|
||||
template <typename T, typename Tin, typename FuncT>
|
||||
__global__ void _ScatterElementsKernel(
|
||||
const int rank,
|
||||
const T* input_data,
|
||||
|
|
@ -46,7 +52,8 @@ __global__ void _ScatterElementsKernel(
|
|||
const TArray<fast_divmod> indices_strides,
|
||||
const T* updates,
|
||||
const int axis,
|
||||
T* output_data) {
|
||||
T* output_data,
|
||||
const FuncT& func) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(indices_index, indices_size);
|
||||
int dim, remain = indices_index;
|
||||
size_t data_idx = 0;
|
||||
|
|
@ -61,7 +68,7 @@ __global__ void _ScatterElementsKernel(
|
|||
}
|
||||
data_idx += input_strides[i] * dim;
|
||||
}
|
||||
output_data[data_idx] = updates[indices_index];
|
||||
func(output_data + data_idx, updates + indices_index);
|
||||
}
|
||||
|
||||
// From the innermost axis (largest) check equality of dim value of input and indices.
|
||||
|
|
@ -135,7 +142,7 @@ static int CompactInputIndicesDims(
|
|||
return new_axis;
|
||||
}
|
||||
|
||||
template <typename T, typename Tin>
|
||||
template <typename T, typename Tin, typename FuncT>
|
||||
Status ScatterElementsImpl2D(
|
||||
const T* input_data,
|
||||
const std::vector<int64_t>& input_dims,
|
||||
|
|
@ -144,23 +151,70 @@ Status ScatterElementsImpl2D(
|
|||
const std::vector<int64_t>& indices_dims,
|
||||
const T* updates,
|
||||
const int axis,
|
||||
T* output_data) {
|
||||
T* output_data,
|
||||
const FuncT& func) {
|
||||
int blocksPerGrid = gsl::narrow_cast<int>(CeilDiv(indices_size, GridDim::maxThreadsPerBlock));
|
||||
fast_divmod indices_stride_row(indices_dims[1]);
|
||||
if (axis == 0) {
|
||||
_ScatterElementsKernel2D<T, Tin, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
_ScatterElementsKernel2D<T, Tin, true, FuncT><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
gsl::narrow_cast<int>(input_dims[0]), input_data,
|
||||
indices_data, indices_size, indices_stride_row,
|
||||
updates, input_dims[1], output_data);
|
||||
updates, input_dims[1], output_data, func);
|
||||
} else {
|
||||
_ScatterElementsKernel2D<T, Tin, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
_ScatterElementsKernel2D<T, Tin, false, FuncT><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
gsl::narrow_cast<int>(input_dims[1]), input_data,
|
||||
indices_data, indices_size, indices_stride_row,
|
||||
updates, input_dims[1], output_data);
|
||||
updates, input_dims[1], output_data, func);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T, typename Tin, typename FuncT>
|
||||
Status ScatterElementsImplInternal(
|
||||
const int rank,
|
||||
const T* input_data,
|
||||
const int64_t input_size,
|
||||
TArray<int64_t>& buffer_input_dims,
|
||||
TArray<int64_t>& buffer_input_strides,
|
||||
const Tin* indices_data,
|
||||
const int64_t indices_size,
|
||||
TArray<int64_t>& buffer_indices_dims,
|
||||
TArray<fast_divmod>& fdm_indices_strides,
|
||||
const T* updates,
|
||||
const int axis,
|
||||
T* output_data,
|
||||
const FuncT& func) {
|
||||
if (input_data != output_data) {
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_data, input_data, input_size * sizeof(T), cudaMemcpyDeviceToDevice, 0));
|
||||
}
|
||||
|
||||
if (indices_size > 0) {
|
||||
std::vector<int64_t> eff_input_dims;
|
||||
std::vector<int64_t> eff_indices_dims;
|
||||
int new_axis = CompactInputIndicesDims(
|
||||
rank, axis, buffer_input_dims.data_, buffer_indices_dims.data_, eff_input_dims, eff_indices_dims);
|
||||
if (eff_input_dims.size() == 2) {
|
||||
return ScatterElementsImpl2D(
|
||||
input_data, eff_input_dims, indices_data, indices_size, eff_indices_dims, updates, new_axis, output_data,
|
||||
func);
|
||||
}
|
||||
|
||||
int blocksPerGrid = gsl::narrow_cast<int>(CeilDiv(indices_size, GridDim::maxThreadsPerBlock));
|
||||
_ScatterElementsKernel<T, Tin><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
rank, input_data, buffer_input_dims, buffer_input_strides,
|
||||
indices_data, indices_size, buffer_indices_dims, fdm_indices_strides,
|
||||
updates, axis, output_data, func);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <class T>
|
||||
struct Func_Assignment {
|
||||
__device__ __inline__ void operator()(T* a, const T* b) const {
|
||||
*a = *b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Tin>
|
||||
Status ScatterElementsImpl(
|
||||
const int rank,
|
||||
|
|
@ -175,60 +229,94 @@ Status ScatterElementsImpl(
|
|||
const T* updates,
|
||||
const int axis,
|
||||
T* output_data) {
|
||||
if (input_data != output_data) {
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_data, input_data, input_size * sizeof(T), cudaMemcpyDeviceToDevice, 0));
|
||||
}
|
||||
|
||||
if (indices_size > 0) {
|
||||
std::vector<int64_t> eff_input_dims;
|
||||
std::vector<int64_t> eff_indices_dims;
|
||||
int new_axis = CompactInputIndicesDims(
|
||||
rank, axis, buffer_input_dims.data_, buffer_indices_dims.data_, eff_input_dims, eff_indices_dims);
|
||||
if (eff_input_dims.size() == 2) {
|
||||
return ScatterElementsImpl2D(
|
||||
input_data, eff_input_dims, indices_data, indices_size, eff_indices_dims, updates, new_axis, output_data);
|
||||
}
|
||||
|
||||
int blocksPerGrid = gsl::narrow_cast<int>(CeilDiv(indices_size, GridDim::maxThreadsPerBlock));
|
||||
_ScatterElementsKernel<T, Tin><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
rank, input_data, buffer_input_dims, buffer_input_strides,
|
||||
indices_data, indices_size, buffer_indices_dims, fdm_indices_strides,
|
||||
updates, axis, output_data);
|
||||
}
|
||||
return Status::OK();
|
||||
return ScatterElementsImplInternal(rank, input_data, input_size, buffer_input_dims,
|
||||
buffer_input_strides, indices_data, indices_size, buffer_indices_dims, fdm_indices_strides,
|
||||
updates, axis, output_data, Func_Assignment<T>());
|
||||
}
|
||||
|
||||
#define SPECIALIZED_TINDEX_IMPL(T, TIndex) \
|
||||
template Status ScatterElementsImpl<T, TIndex>( \
|
||||
const int rank, \
|
||||
const T* input_data, \
|
||||
const int64_t input_size, \
|
||||
TArray<int64_t>& buffer_input_dims, \
|
||||
TArray<int64_t>& buffer_input_strides, \
|
||||
const TIndex* indices_data, \
|
||||
const int64_t indices_size, \
|
||||
TArray<int64_t>& buffer_indices_dims, \
|
||||
TArray<fast_divmod>& indices_strides, \
|
||||
const T* updates, \
|
||||
const int axis, \
|
||||
#define SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, TIndex) \
|
||||
template Status ScatterElementsImpl<T, TIndex>( \
|
||||
const int rank, \
|
||||
const T* input_data, \
|
||||
const int64_t input_size, \
|
||||
TArray<int64_t>& buffer_input_dims, \
|
||||
TArray<int64_t>& buffer_input_strides, \
|
||||
const TIndex* indices_data, \
|
||||
const int64_t indices_size, \
|
||||
TArray<int64_t>& buffer_indices_dims, \
|
||||
TArray<fast_divmod>& indices_strides, \
|
||||
const T* updates, \
|
||||
const int axis, \
|
||||
T* output_data)
|
||||
|
||||
#define SPECIALIZED_IMPL(T) \
|
||||
SPECIALIZED_TINDEX_IMPL(T, int32_t); \
|
||||
SPECIALIZED_TINDEX_IMPL(T, int64_t);
|
||||
#define SCATTER_ELEMENTS_SPECIALIZED_IMPL(T) \
|
||||
SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, int32_t); \
|
||||
SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, int64_t);
|
||||
|
||||
SPECIALIZED_IMPL(int8_t)
|
||||
SPECIALIZED_IMPL(int16_t)
|
||||
SPECIALIZED_IMPL(int32_t)
|
||||
SPECIALIZED_IMPL(int64_t)
|
||||
SPECIALIZED_IMPL(uint8_t)
|
||||
SPECIALIZED_IMPL(uint16_t)
|
||||
SPECIALIZED_IMPL(uint32_t)
|
||||
SPECIALIZED_IMPL(uint64_t)
|
||||
SPECIALIZED_IMPL(half)
|
||||
SPECIALIZED_IMPL(float)
|
||||
SPECIALIZED_IMPL(double)
|
||||
SPECIALIZED_IMPL(bool)
|
||||
SCATTER_ELEMENTS_SPECIALIZED_IMPL(int8_t)
|
||||
SCATTER_ELEMENTS_SPECIALIZED_IMPL(int16_t)
|
||||
SCATTER_ELEMENTS_SPECIALIZED_IMPL(int32_t)
|
||||
SCATTER_ELEMENTS_SPECIALIZED_IMPL(int64_t)
|
||||
SCATTER_ELEMENTS_SPECIALIZED_IMPL(uint8_t)
|
||||
SCATTER_ELEMENTS_SPECIALIZED_IMPL(uint16_t)
|
||||
SCATTER_ELEMENTS_SPECIALIZED_IMPL(uint32_t)
|
||||
SCATTER_ELEMENTS_SPECIALIZED_IMPL(uint64_t)
|
||||
SCATTER_ELEMENTS_SPECIALIZED_IMPL(half)
|
||||
SCATTER_ELEMENTS_SPECIALIZED_IMPL(float)
|
||||
SCATTER_ELEMENTS_SPECIALIZED_IMPL(double)
|
||||
SCATTER_ELEMENTS_SPECIALIZED_IMPL(bool)
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
|
||||
template <class T>
|
||||
struct Func_AtomicAdd {
|
||||
__device__ __inline__ void operator()(T* a, const T* b) const {
|
||||
atomic_add(a, *b);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Tin>
|
||||
Status GatherElementsGradImpl(
|
||||
const int rank,
|
||||
TArray<int64_t>& buffer_input_dims,
|
||||
TArray<int64_t>& buffer_input_strides,
|
||||
const Tin* indices_data,
|
||||
const int64_t indices_size,
|
||||
TArray<int64_t>& buffer_indices_dims,
|
||||
TArray<fast_divmod>& fdm_indices_strides,
|
||||
const T* updates,
|
||||
const int axis,
|
||||
T* output_data) {
|
||||
// Give output_data as the input_data parameter by intention,
|
||||
// to skip input_data copy, which is not applicable for GatherElementsGrad.
|
||||
return ScatterElementsImplInternal(rank, output_data, 0,
|
||||
buffer_input_dims, buffer_input_strides, indices_data,
|
||||
indices_size, buffer_indices_dims, fdm_indices_strides,
|
||||
updates, axis, output_data, Func_AtomicAdd<T>());
|
||||
}
|
||||
|
||||
#define GATHER_ELEMENTS_GRAD_SPECIALIZED_TINDEX_IMPL(T, TIndex) \
|
||||
template Status GatherElementsGradImpl<T, TIndex>( \
|
||||
const int rank, \
|
||||
TArray<int64_t>& buffer_input_dims, \
|
||||
TArray<int64_t>& buffer_input_strides, \
|
||||
const TIndex* indices_data, \
|
||||
const int64_t indices_size, \
|
||||
TArray<int64_t>& buffer_indices_dims, \
|
||||
TArray<fast_divmod>& indices_strides, \
|
||||
const T* updates, \
|
||||
const int axis, \
|
||||
T* output_data)
|
||||
|
||||
#define GATHER_ELEMENTS_GRAD_SPECIALIZED_SCATTER_ADD_IMPL(T) \
|
||||
GATHER_ELEMENTS_GRAD_SPECIALIZED_TINDEX_IMPL(T, int32_t); \
|
||||
GATHER_ELEMENTS_GRAD_SPECIALIZED_TINDEX_IMPL(T, int64_t);
|
||||
|
||||
GATHER_ELEMENTS_GRAD_SPECIALIZED_SCATTER_ADD_IMPL(half)
|
||||
GATHER_ELEMENTS_GRAD_SPECIALIZED_SCATTER_ADD_IMPL(float)
|
||||
GATHER_ELEMENTS_GRAD_SPECIALIZED_SCATTER_ADD_IMPL(double)
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -26,4 +26,3 @@ Status ScatterElementsImpl(
|
|||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@
|
|||
#include "core/optimizer/reshape_fusion.h"
|
||||
#include "core/optimizer/attention_fusion.h"
|
||||
#include "core/optimizer/fast_gelu_fusion.h"
|
||||
#include "core/optimizer/expand_elimination.h"
|
||||
#include "core/optimizer/cast_elimination.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/platform/env.h"
|
||||
|
|
@ -1115,6 +1116,24 @@ TEST_F(GraphTransformationTests, ReshapeFusionInternalReuseTest) {
|
|||
}
|
||||
}
|
||||
|
||||
// Loads expand_elimination.onnx, runs only the ExpandElimination rule at Level1,
// and checks that the number of Expand nodes drops from 6 to 3.
TEST_F(GraphTransformationTests, ExpandElimination) {
  auto model_uri = MODEL_FOLDER "expand_elimination.onnx";
  std::shared_ptr<Model> p_model;
  ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK());
  Graph& graph = p_model->MainGraph();

  // Baseline: count Expand nodes before any transformation.
  std::map<std::string, int> counts_before = CountOpsInGraph(graph);
  ASSERT_TRUE(counts_before["Expand"] == 6);

  // Build a rule-based transformer holding just the rule under test.
  auto expand_rule_transformer = onnxruntime::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
  expand_rule_transformer->Register(onnxruntime::make_unique<ExpandElimination>());

  onnxruntime::GraphTransformerManager transformer_mgr{5};
  transformer_mgr.Register(std::move(expand_rule_transformer), TransformerLevel::Level1);
  ASSERT_TRUE(transformer_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_).IsOK());

  // After the rule has run, only three Expand nodes should remain.
  std::map<std::string, int> counts_after = CountOpsInGraph(graph);
  ASSERT_TRUE(counts_after["Expand"] == 3);
}
|
||||
|
||||
TEST_F(GraphTransformationTests, CastElimination) {
|
||||
auto model_uri = MODEL_FOLDER "cast_elimination.onnx";
|
||||
std::shared_ptr<Model> model;
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/test_training_model_0.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/test_training_model_0.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/test_training_model_1.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/test_training_model_1.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/test_training_model_2.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/test_training_model_2.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/expand_elimination.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/expand_elimination.onnx
vendored
Normal file
Binary file not shown.
56
onnxruntime/test/testdata/transform/expand_elimination.py
vendored
Normal file
56
onnxruntime/test/testdata/transform/expand_elimination.py
vendored
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import TensorProto, GraphProto, OperatorSetIdProto
|
||||
from onnx import numpy_helper
|
||||
import numpy as np
|
||||
|
||||
# Graph inputs and output: one fully static input, one input with a symbolic
# leading dimension, and a static rank-3 output.
X1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [2, 1])
X2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, ['dynamic', 4])
Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2, 4])

# Target shapes consumed by expand1..expand6, stored as int64 initializers
# named shape_constant1..shape_constant6 (position in the tuple + 1).
expand_target_shapes = ([1, 4], [1, 1], [2, 1], [1, 1, 1], [1, 4], [2, 1])
shape_constants = [
    numpy_helper.from_array(np.array(dims, dtype=np.int64), name='shape_constant' + str(position))
    for position, dims in enumerate(expand_target_shapes, start=1)
]


def make_named_node(op_type, inputs, output, name=None):
    """Build a single-output node; the node name defaults to the output name."""
    return helper.make_node(op_type, inputs, [output], name=name or output)


def make_expand(index, source):
    """Build Expand node number `index`, fed by `source` and shape_constant<index>."""
    return make_named_node('Expand', [source, shape_constants[index - 1].name], 'expand' + str(index))


# Nodes are listed in the order they are stored in the graph.
nodes = [
    make_named_node('Identity', ['input1'], 'identity1'),
    make_expand(1, 'identity1'),
    make_expand(2, 'identity1'),
    make_named_node('Mul', ['expand1', 'expand2'], 'mul1'),  # (2, 4)
    make_expand(3, 'mul1'),
    make_expand(4, 'identity1'),
    make_named_node('Mul', ['expand3', 'expand4'], 'mul2'),  # (1, 2, 4)
    make_named_node('Identity', ['input2'], 'identity2'),
    make_expand(5, 'identity2'),
    make_expand(6, 'identity2'),
    make_named_node('Mul', ['expand5', 'expand6'], 'mul3'),  # (dynamic=2, 4)
    make_named_node('Mul', ['mul2', 'mul3'], 'output', name='mul4'),
]

# Create the graph (GraphProto)
graph_def = helper.make_graph(
    nodes,
    'expand_elimination_model',
    [X1, X2],
    [Y],
    shape_constants
)


def make_opset(domain, version):
    """Build an OperatorSetIdProto for the given domain and version."""
    opset = OperatorSetIdProto()
    opset.version = version
    opset.domain = domain
    return opset


# The empty string ("") or absence of the domain field implies the operator set
# that is defined as part of the ONNX specification.
opsets = [make_opset("", 12), make_opset('com.microsoft', 1)]

# Create the model (ModelProto) and write it next to this script's working directory.
model_def = helper.make_model(graph_def, producer_name='onnx-example', opset_imports=opsets)
onnx.save(model_def, 'expand_elimination.onnx')
|
||||
|
|
@ -573,6 +573,17 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherGradient) {
|
|||
SrcNodeAttributes())};
|
||||
}
|
||||
|
||||
// Gradient of GatherElements with respect to its data input.
// A Shape node records the shape of the forward data input I(0) as "x_shape".
// GatherElementsGrad (kMSDomain, version 1) then takes the output gradient GO(0),
// that shape and the forward indices I(1), and produces the data gradient GI(0).
// The forward node's attributes are passed through via SrcNodeAttributes().
IMPLEMENT_GRADIENT_BUILDER(GetGatherElementsGradient) {
  return std::vector<NodeDef>{
      NodeDef("Shape",
              {I(0)},
              {IA("x_shape")}),
      NodeDef(OpDef{"GatherElementsGrad", kMSDomain, 1},
              {GO(0), IA("x_shape"), I(1)},
              {GI(0)},
              SrcNodeAttributes())};
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetReluGradient) {
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef("ReluGrad",
|
||||
|
|
@ -1034,7 +1045,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetSendGradient) {
|
|||
}
|
||||
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef("Recv",
|
||||
NodeDef(OpDef{"Recv", kMSDomain, 1},
|
||||
{O(0), I(1)}, // {Signal, Remote}
|
||||
out_args,
|
||||
SrcNodeAttributes())};
|
||||
|
|
@ -1046,18 +1057,38 @@ IMPLEMENT_GRADIENT_BUILDER(GetRecvGradient) {
|
|||
|
||||
std::vector<ArgDef> in_args;
|
||||
in_args.push_back(O(0)); // Signal
|
||||
in_args.push_back(I(0)); // Remote
|
||||
in_args.push_back(I(1)); // Remote
|
||||
|
||||
for (int i = 1; i < GetSrcNodeOutputSize(); ++i) {
|
||||
in_args.push_back(GO(i)); // Data
|
||||
}
|
||||
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef("Send",
|
||||
NodeDef(OpDef{"Send", kMSDomain, 1},
|
||||
in_args,
|
||||
{GI(0)}, // Signal
|
||||
SrcNodeAttributes())};
|
||||
}
|
||||
|
||||
// Gradient of Expand with respect to its data input.
// The shapes of the forward input I(0) and forward output O(0) are compared to
// find the axes along which the input was broadcast. If there are any, the
// output gradient is handed to HandleBroadcasting together with those axes;
// otherwise the output gradient is forwarded unchanged through an Identity node.
IMPLEMENT_GRADIENT_BUILDER(GetExpandGradient) {
  ArgDef input = I(0);
  ArgDef expanded = O(0);
  std::vector<Dimension> input_dims = GetShape(input);
  std::vector<Dimension> expanded_dims = GetShape(expanded);

  std::vector<int64_t> broadcast_axes;
  ComputeBroadcastBackwardAxes(input_dims, expanded_dims, &broadcast_axes, nullptr);

  std::vector<NodeDef> grad_nodes;
  if (broadcast_axes.empty()) {
    // No broadcasting took place, so the gradient is passed straight through.
    grad_nodes.push_back(NodeDef("Identity", {GO(0)}, {GI(0)}));
    return grad_nodes;
  }

  HandleBroadcasting(GO(0), input, GI(0), broadcast_axes, grad_nodes);
  return grad_nodes;
}
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ DECLARE_GRADIENT_BUILDER(GetGemmGradient)
|
|||
DECLARE_GRADIENT_BUILDER(GetDropoutGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetTrainableDropoutGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetGatherNDGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetGatherElementsGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetGeluGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetLayerNormalizationGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetBatchNormalizationGradient)
|
||||
|
|
@ -55,6 +56,7 @@ DECLARE_GRADIENT_BUILDER(GetFastGeluGradient)
|
|||
DECLARE_GRADIENT_BUILDER(GetWhereGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetSendGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetRecvGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetExpandGradient)
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -61,6 +61,7 @@ void ComputeBroadcastBackwardAxes(
|
|||
}
|
||||
|
||||
std::vector<Dimension> GetShape(const ArgDef& arg_def) {
|
||||
ORT_ENFORCE(arg_def.type_proto, "During GetShape, ", arg_def.name, "'s type_proto is null.");
|
||||
std::vector<Dimension> shape;
|
||||
const auto& dims = arg_def.type_proto->tensor_type().shape().dim();
|
||||
for (auto dim = dims.begin(); dim < dims.end(); dim++) {
|
||||
|
|
|
|||
|
|
@ -83,6 +83,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
|
|||
REGISTER_GRADIENT_BUILDER("Dropout", GetDropoutGradient)
|
||||
REGISTER_GRADIENT_BUILDER("TrainableDropout", GetTrainableDropoutGradient)
|
||||
REGISTER_GRADIENT_BUILDER("GatherND", GetGatherNDGradient)
|
||||
REGISTER_GRADIENT_BUILDER("GatherElements", GetGatherElementsGradient)
|
||||
REGISTER_GRADIENT_BUILDER("Gelu", GetGeluGradient)
|
||||
REGISTER_GRADIENT_BUILDER("LayerNormalization", GetLayerNormalizationGradient);
|
||||
REGISTER_GRADIENT_BUILDER("BatchNormalization", GetBatchNormalizationGradient);
|
||||
|
|
@ -93,6 +94,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
|
|||
REGISTER_GRADIENT_BUILDER("Where", GetWhereGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Send", GetSendGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Recv", GetRecvGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Expand", GetExpandGradient);
|
||||
};
|
||||
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -432,6 +432,43 @@ void RegisterGradientSchemas() {
|
|||
{"tensor(int32)", "tensor(int64)"},
|
||||
"Constrain indices to integer types");
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(GatherElementsGrad)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
|
||||
.SetDoc("GatherElementsGrad")
|
||||
.Attr(
|
||||
"axis",
|
||||
"Which axis to scatter on. Negative value means "
|
||||
"counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(data).",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(0))
|
||||
.Input(
|
||||
0,
|
||||
"dY",
|
||||
"Tensor of rank r >=1 (same rank and shape as indices)",
|
||||
"T")
|
||||
.Input(1, "shape", "Shape of the GatherElements input data.", "I")
|
||||
.Input(
|
||||
2,
|
||||
"indices",
|
||||
"Tensor of int32/int64 indices, of r >= 1 (same rank as input). All index values are expected to be "
|
||||
"within bounds [-s, s-1] along axis of size s. It is an error if any of the index values are out of bounds.",
|
||||
"Tind")
|
||||
.Output(0, "dX", "Tensor of rank r >= 1 (same rank as input).", "T")
|
||||
.TypeConstraint(
|
||||
"I",
|
||||
{"tensor(int64)"},
|
||||
"Constrain input shape to integer tensors.")
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
{"tensor(float16)", "tensor(float)", "tensor(double)"},
|
||||
"Input and output types can be of any tensor type.")
|
||||
.TypeConstraint(
|
||||
"Tind",
|
||||
{"tensor(int32)", "tensor(int64)"},
|
||||
"Constrain indices to integer types");
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(DivGrad)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
|
|
@ -1633,7 +1670,7 @@ Return true if all elements are true and false otherwise.
|
|||
"Allow inputs and outputs to be any kind of tensor.")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
if (ctx.getNumInputs() < ctx.getNumOutputs() + 1)
|
||||
fail_shape_inference("WaitEvent must have at least (num_outputs + 1) inputs.");
|
||||
fail_shape_inference("RecordEvent must have at least (num_outputs + 1) inputs.");
|
||||
|
||||
// note: if num_input > num_output + 1,
|
||||
// the additional inputs (idx >= num_output + 1) are regarded as dependencies
|
||||
|
|
|
|||
279
orttraining/orttraining/core/graph/pipeline_transformer.cc
Normal file
279
orttraining/orttraining/core/graph/pipeline_transformer.cc
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/core/graph/pipeline_transformer.h"
|
||||
|
||||
using namespace onnxruntime::common;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
||||
void GetPipelineSendOutput(const Graph& graph, std::string& loss_name) {
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (!node.OpType().compare("Send")) {
|
||||
// send op should always have an output, which is the OutputSignal.
|
||||
loss_name = node.OutputDefs()[0]->Name();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool IsBackward(Node& node) {
|
||||
return (node.Description() == "Backward pass");
|
||||
}
|
||||
|
||||
// Creates an int64 tensor NodeArg that serves as the event-id input of an
// event op. The arg name is generated from "<op_name>_fw_event_id" or
// "<op_name>_bw_event_id" depending on `is_forward`. The new arg is appended
// to `input_args` and its name to `new_input_names`.
void AddInputEvent(Graph& graph, const std::string& op_name,
                   bool is_forward,
                   std::vector<NodeArg*>& input_args,
                   std::vector<std::string>& new_input_names) {
  const char* const direction_tag = is_forward ? "_fw" : "_bw";
  const std::string event_arg_name = graph.GenerateNodeArgName(op_name + direction_tag + "_event_id");

  ONNX_NAMESPACE::TypeProto int64_tensor_type;
  int64_tensor_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);

  NodeArg& event_arg = graph.GetOrCreateNodeArg(event_arg_name, &int64_tensor_type);
  input_args.push_back(&event_arg);
  new_input_names.push_back(event_arg_name);
}
|
||||
|
||||
// gradient graph can contain some dangling leaf nodes. Add them all to WaitEvent
|
||||
// backward node's input.
|
||||
void FindLeafNodes(Graph& graph, std::vector<NodeArg*>& input_args) {
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (!IsBackward(node)) {
|
||||
// only check backward node
|
||||
continue;
|
||||
}
|
||||
bool find_consumer_nodes = false;
|
||||
std::vector<NodeArg*>& outputs = node.MutableOutputDefs();
|
||||
for (auto& output : outputs) {
|
||||
std::vector<const Node*> consumer_nodes = graph.GetConsumerNodes(output->Name());
|
||||
if (consumer_nodes.size() > 0) {
|
||||
find_consumer_nodes = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!find_consumer_nodes && outputs.size() > 0) {
|
||||
input_args.push_back(outputs[0]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Creates a new NodeArg in `graph` with the same type as `base_arg` and a fresh
// name generated from `base_arg`'s name. Throws if a NodeArg with the generated
// name already exists in the graph.
NodeArg& CreateNodeArg(Graph& graph, const NodeArg* base_arg) {
  const std::string new_name = graph.GenerateNodeArgName(base_arg->Name());
  ONNX_NAMESPACE::TypeProto type_proto(*(base_arg->TypeAsProto()));
  if (graph.GetNodeArg(new_name) != nullptr) {
    // The colliding entity is a NodeArg, not a Node, so the message says so.
    ORT_THROW("NodeArg with name ", new_name, " already exists.");
  }
  return graph.GetOrCreateNodeArg(new_name, &type_proto);
}
|
||||
|
||||
// Adds the backward-pass RecordEvent node to `graph`.
// Its inputs are, in order: a new backward event-id arg (whose name is appended
// to `new_input_names`), the outputs of `send_bw` if that node is present, and
// the first output of every dangling backward leaf node (see FindLeafNodes).
// Its single output is a new NodeArg typed like the second input; that output's
// name is appended to `new_output_names`.
// Returns a failure status if fewer than two inputs were collected.
Status AddRecordBackward(Graph& graph,
                         Node* send_bw,
                         std::vector<std::string>& new_input_names,
                         std::vector<std::string>& new_output_names) {
  std::vector<NodeArg*> input_args;
  AddInputEvent(graph, "RecordEvent", false /* is_forward */, input_args, new_input_names);
  std::vector<NodeArg*> output_args{};

  if (send_bw) {
    // if we have send op in backward pass (at the end of the graph), we make sure the RecordEvent happens
    // after that send by adding Send's outputs to RecordEvent's input list.
    input_args.insert(std::end(input_args),
                      std::begin(send_bw->MutableOutputDefs()),
                      std::end(send_bw->MutableOutputDefs()));
  }
  FindLeafNodes(graph, input_args);

  // Optimizer will be added after applying pipeline transformer. To support partial graph evaluation,
  // the added Record backward op will have its first passthrough input as output.
  ORT_RETURN_IF_NOT(input_args.size() >= 2, "RecordEvent backward op at least have two inputs.");
  auto& new_output = CreateNodeArg(graph, input_args[1]);  // the first input is signal, not passing through
  output_args.push_back(&new_output);
  new_output_names.push_back(new_output.Name());

  graph.AddNode(graph.GenerateNodeName("RecordEvent"),
                "RecordEvent",
                "Backward pass",
                input_args,
                output_args,
                nullptr,
                kMSDomain);
  return Status::OK();
}
|
||||
|
||||
// Adds the forward-pass WaitEvent node to `graph`.
// The node takes a new forward event-id arg (its name is appended to
// `new_input_names`) followed by every graph input, initializers included.
// For each graph input a pass-through output with a freshly generated name and
// the same type is created, and every consumer of that graph input is rewired
// to read the pass-through output instead.
// Throws if the graph has no inputs.
Status AddWaitForward(Graph& graph, Node* /* recv_fw */, std::vector<std::string>& new_input_names) {
  std::vector<NodeArg*> wait_inputs;
  std::vector<NodeArg*> wait_outputs;
  AddInputEvent(graph, "WaitEvent", true /* is_forward */, wait_inputs, new_input_names);

  const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();
  if (graph_inputs.empty()) {
    ORT_THROW("Graph ", graph.Name(), " doesn't have any inputs.");
  }

  for (const NodeArg* graph_input : graph_inputs) {
    const std::string& input_name = graph_input->Name();

    // The graph input becomes an input of WaitEvent ...
    NodeArg* original_arg = graph.GetNodeArg(input_name);
    wait_inputs.push_back(original_arg);

    // ... and gets a matching pass-through output of the same type.
    const std::string passthrough_name = graph.GenerateNodeArgName(original_arg->Name());
    ONNX_NAMESPACE::TypeProto passthrough_type(*(original_arg->TypeAsProto()));
    NodeArg& passthrough_arg = graph.GetOrCreateNodeArg(passthrough_name, &passthrough_type);
    wait_outputs.push_back(&passthrough_arg);

    // Nodes fed by the graph input are redirected to WaitEvent's output.
    std::vector<Node*> consumers = graph.GetMutableConsumerNodes(input_name);
    for (Node* consumer : consumers) {
      for (NodeArg*& consumer_input : consumer->MutableInputDefs()) {
        if (consumer_input->Name() == input_name) {
          consumer_input = &passthrough_arg;
        }
      }
    }
  }

  graph.AddNode(graph.GenerateNodeName("WaitEvent"),
                "WaitEvent",
                "",
                wait_inputs,
                wait_outputs,
                nullptr,
                kMSDomain);

  return Status::OK();
}
|
||||
|
||||
// Inserts a forward RecordEvent and a backward WaitEvent between the forward
// Send node and the backward Recv node.
//  - Throws if exactly one of `send_fw` / `recv_bw` is null.
//  - Returns OK without touching the graph if both are null.
//  - Otherwise RecordEvent takes a new forward event-id arg plus all outputs of
//    `send_fw`, and emits one new pass-through arg per Send output. WaitEvent
//    (marked "Backward pass") takes a new backward event-id arg plus all
//    RecordEvent outputs, and emits a new arg that replaces the first input of
//    `recv_bw`.
// The names of both event-id args are appended to `new_input_names`.
Status AddOrSkipRecordForwardWaitBackward(Graph& graph, Node* send_fw, Node* recv_bw, std::vector<std::string>& new_input_names) {
  if (!send_fw != !recv_bw) {
    ORT_THROW("Graph requires either having both send forward node "
              "and recv backword node, or none of them. Currently the graph "
              "has send forward: ", send_fw, " and recv backward: ", recv_bw);
  }

  if (!send_fw && !recv_bw) {
    // Last partition doesn't have send forward and recv backward. No insert needed.
    return Status::OK();
  }

  // We have a send forward op followed by a recv backward op: insert RecordEvent
  // and WaitEvent in between.

  // Insert RecordEvent
  std::vector<NodeArg*> record_inputs;
  std::vector<NodeArg*> record_outputs;
  AddInputEvent(graph, "RecordEvent", true /* is_forward */, record_inputs, new_input_names);

  // Add send forward op's output as record op's input and output
  for (auto& send_output : send_fw->MutableOutputDefs()) {
    auto& record_output = CreateNodeArg(graph, send_output);
    record_outputs.push_back(&record_output);
    record_inputs.push_back(send_output);
  }

  Node& record_node = graph.AddNode(graph.GenerateNodeName("RecordEvent"),
                                    "RecordEvent",
                                    "",
                                    record_inputs,
                                    record_outputs,
                                    nullptr, /* attribute */
                                    kMSDomain);

  // Insert WaitEvent
  std::vector<NodeArg*> wait_inputs;
  std::vector<NodeArg*> wait_outputs;
  AddInputEvent(graph, "WaitEvent", false /* is_forward */, wait_inputs, new_input_names);

  wait_inputs.insert(std::end(wait_inputs),
                     std::begin(record_node.MutableOutputDefs()),
                     std::end(record_node.MutableOutputDefs()));

  // Recv's first input is redirected to the new WaitEvent output.
  auto& recv_first_input = recv_bw->MutableInputDefs()[0];
  auto& wait_output = CreateNodeArg(graph, recv_first_input);
  wait_outputs.push_back(&wait_output);
  recv_first_input = &wait_output;

  graph.AddNode(graph.GenerateNodeName("WaitEvent"),
                "WaitEvent",
                "Backward pass",
                wait_inputs,
                wait_outputs,
                nullptr, /* attribute */
                kMSDomain);

  return Status::OK();
}
|
||||
|
||||
// Inserts the WaitEvent / RecordEvent nodes needed for pipeline execution into
// this partition's graph.
// The forward and backward Send / Recv nodes are located first (backward nodes
// are recognized via IsBackward). The three insertion helpers then run, the
// newly created event-id args are appended to the graph inputs, the new
// backward RecordEvent output is appended to the graph outputs, and the graph
// is resolved. Returns the first failing status, or the result of Resolve().
Status TransformGraphForPipeline(Graph& graph) {
  Node* forward_send = nullptr;
  Node* backward_send = nullptr;
  Node* forward_recv = nullptr;
  Node* backward_recv = nullptr;

  for (auto& node : graph.Nodes()) {
    const bool is_send = node.OpType() == "Send";
    const bool is_recv = !is_send && node.OpType() == "Recv";
    if (!is_send && !is_recv) {
      continue;
    }

    const bool is_backward = IsBackward(node);
    if (is_send) {
      (is_backward ? backward_send : forward_send) = &node;
    } else {
      (is_backward ? backward_recv : forward_recv) = &node;
    }
  }

  std::vector<std::string> added_input_names;
  std::vector<std::string> added_output_names;

  ORT_RETURN_IF_ERROR(AddRecordBackward(graph, backward_send, added_input_names, added_output_names));
  ORT_RETURN_IF_ERROR(AddWaitForward(graph, forward_recv, added_input_names));
  ORT_RETURN_IF_ERROR(AddOrSkipRecordForwardWaitBackward(graph, forward_send, backward_recv, added_input_names));

  // Returns the existing args followed by the args looked up from `added_names`.
  const auto merge_node_args = [&graph](const std::vector<const NodeArg*>& existing_args,
                                        const std::vector<std::string>& added_names) {
    std::vector<const NodeArg*> merged;
    merged.reserve(existing_args.size() + added_names.size());
    merged.insert(merged.end(), existing_args.begin(), existing_args.end());
    for (const auto& added_name : added_names) {
      merged.push_back(graph.GetNodeArg(added_name));
    }
    return merged;
  };

  // Both merged lists are built before either is written back to the graph.
  const std::vector<const NodeArg*> merged_inputs =
      merge_node_args(graph.GetInputsIncludingInitializers(), added_input_names);
  const std::vector<const NodeArg*> merged_outputs =
      merge_node_args(graph.GetOutputs(), added_output_names);

  graph.SetInputs(merged_inputs);
  graph.SetOutputs(merged_outputs);
  graph.SetGraphResolveNeeded();
  graph.SetGraphProtoSyncNeeded();
  return graph.Resolve();
}
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
15
orttraining/orttraining/core/graph/pipeline_transformer.h
Normal file
15
orttraining/orttraining/core/graph/pipeline_transformer.h
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/graph/graph.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
||||
// Stores in `loss_name` the name of the first output of the first "Send" node
// found in `graph`. `loss_name` is left unchanged if the graph has no "Send" node.
void GetPipelineSendOutput(const Graph& graph, std::string& loss_name);
|
||||
// Inserts WaitEvent and RecordEvent nodes into `graph` for pipeline execution,
// updates the graph's inputs and outputs accordingly and resolves the graph.
// Returns the resulting status.
common::Status TransformGraphForPipeline(Graph& graph);
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -12,6 +12,7 @@
|
|||
#include "core/optimizer/conv_add_fusion.h"
|
||||
#include "core/optimizer/constant_folding.h"
|
||||
#include "core/optimizer/unsqueeze_elimination.h"
|
||||
#include "core/optimizer/expand_elimination.h"
|
||||
#include "core/optimizer/cast_elimination.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
#include "core/optimizer/conv_activation_fusion.h"
|
||||
|
|
@ -54,6 +55,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(T
|
|||
rule_transformer->Register(make_unique<InsertMaxPoolOutput>());
|
||||
rule_transformer->Register(make_unique<AdjustBatchNormOutputs>());
|
||||
rule_transformer->Register(make_unique<UnsqueezeElimination>());
|
||||
rule_transformer->Register(make_unique<ExpandElimination>());
|
||||
rule_transformer->Register(make_unique<CastElimination>());
|
||||
rule_transformer->Register(make_unique<InsertSoftmaxCrossEntropyLossOutput>());
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
#include "orttraining/core/graph/mixed_precision_transformer.h"
|
||||
#include "orttraining/core/graph/tensorboard_transformer.h"
|
||||
#include "orttraining/core/graph/pipeline_transformer.h"
|
||||
#include "orttraining/core/graph/gradient_builder_base.h"
|
||||
|
||||
//Gist Encoding
|
||||
|
|
@ -128,15 +129,23 @@ Status TrainingSession::ConfigureForTraining(
|
|||
is_mixed_precision_enabled_ = config.mixed_precision_config.has_value();
|
||||
|
||||
std::string loss_name{};
|
||||
const optional<LossFunctionInfo> loss_function_info =
|
||||
config.loss_function_config.has_value()
|
||||
? config.loss_function_config.value().loss_function_info
|
||||
: optional<LossFunctionInfo>{};
|
||||
optional<std::string> loss_scale_input_name =
|
||||
is_mixed_precision_enabled_ ? optional<std::string>{""} : optional<std::string>{};
|
||||
ORT_RETURN_IF_ERROR(ConfigureLossFunction(
|
||||
config.loss_name, loss_function_info,
|
||||
loss_scale_input_name.has_value() ? &loss_scale_input_name.value() : nullptr, loss_name));
|
||||
is_mixed_precision_enabled_ ? optional<std::string>{""} : optional<std::string>{};
|
||||
if (config.use_pipeline) {
|
||||
// if use pipeline, first check if model contains send op. If it does, set the
|
||||
// send node's output as the start tensor to build gradient graph
|
||||
GetPipelineSendOutput(model_->MainGraph(), loss_name);
|
||||
}
|
||||
if (loss_name.empty()) {
|
||||
const optional<LossFunctionInfo> loss_function_info =
|
||||
config.loss_function_config.has_value()
|
||||
? config.loss_function_config.value().loss_function_info
|
||||
: optional<LossFunctionInfo>{};
|
||||
ORT_RETURN_IF_ERROR(ConfigureLossFunction(
|
||||
config.loss_name, loss_function_info,
|
||||
loss_scale_input_name.has_value() ? &loss_scale_input_name.value() : nullptr, loss_name));
|
||||
}
|
||||
|
||||
ORT_ENFORCE(
|
||||
!loss_scale_input_name.has_value() || !loss_scale_input_name.value().empty(),
|
||||
"loss_scale_input_name should not be set to an empty string.");
|
||||
|
|
@ -170,7 +179,6 @@ Status TrainingSession::ConfigureForTraining(
|
|||
<< weight_names_stream.str();
|
||||
}
|
||||
|
||||
// add gradient graph
|
||||
ORT_RETURN_IF_ERROR(BuildGradientGraph(
|
||||
weight_names_to_train, loss_name, config.set_gradients_as_graph_outputs));
|
||||
|
||||
|
|
@ -182,12 +190,30 @@ Status TrainingSession::ConfigureForTraining(
|
|||
weight_names_to_train, mixed_precision_config.use_fp16_initializers, fp32_weight_name_to_fp16_node_arg));
|
||||
}
|
||||
|
||||
if (config.use_pipeline) {
|
||||
ORT_RETURN_IF_ERROR(InsertPipelineOps());
|
||||
}
|
||||
|
||||
// All non-float tensors are not trainable. Remove those weights.
|
||||
// TODO: this is a temp workaround for removing rank tensor before adding optimizer.
|
||||
// Re-visit after we port logic for model splitting and hence know the rank tensor name.
|
||||
for (auto it = weights_to_train_.begin(); it != weights_to_train_.end();) {
|
||||
const auto* node_arg = model_->MainGraph().GetNodeArg(*it);
|
||||
ORT_RETURN_IF_NOT(node_arg, "Failed to get NodeArg with name ", *it);
|
||||
if (node_arg->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
it = weights_to_train_.erase(it);
|
||||
}
|
||||
else{
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
// add optimizer or gradient accumulation
|
||||
if (config.optimizer_config.has_value()) {
|
||||
OptimizerGraphConfig opt_graph_config{};
|
||||
std::unordered_map<std::string, OptimizerNodeConfig> opt_node_configs{};
|
||||
ORT_RETURN_IF_ERROR(SetupOptimizerParams(
|
||||
weight_names_to_train, fp32_weight_name_to_fp16_node_arg,
|
||||
weights_to_train_, fp32_weight_name_to_fp16_node_arg,
|
||||
loss_scale_input_name, config, opt_graph_config, opt_node_configs));
|
||||
|
||||
TrainingConfigurationResult::OptimizerConfigurationResult optimizer_config_result{};
|
||||
|
|
@ -198,7 +224,7 @@ Status TrainingSession::ConfigureForTraining(
|
|||
config_result.opt_config_result = optimizer_config_result;
|
||||
} else {
|
||||
if (config.gradient_accumulation_steps > 1) {
|
||||
ORT_RETURN_IF_ERROR(BuildAccumulationNode(weight_names_to_train));
|
||||
ORT_RETURN_IF_ERROR(BuildAccumulationNode(weights_to_train_));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -439,6 +465,11 @@ Status TrainingSession::AddTensorboard(const std::string& summary_name,
|
|||
return DoPostLoadProcessing(*model_);
|
||||
}
|
||||
|
||||
// Applies the pipeline transformation to the model's main graph and, if that
// succeeds, re-runs the post-load processing on the model.
Status TrainingSession::InsertPipelineOps() {
  Graph& main_graph = model_->MainGraph();
  ORT_RETURN_IF_ERROR(TransformGraphForPipeline(main_graph));
  return DoPostLoadProcessing(*model_);
}
|
||||
|
||||
Status TrainingSession::ConfigureLossFunction(
|
||||
const optional<std::string>& external_loss_name,
|
||||
const optional<LossFunctionInfo>& loss_function_info,
|
||||
|
|
|
|||
|
|
@ -132,6 +132,9 @@ class TrainingSession : public InferenceSession {
|
|||
// The optimizer configuration.
|
||||
// If not provided, no optimizer is added.
|
||||
optional<OptimizerConfiguration> optimizer_config{};
|
||||
|
||||
// Whether to use pipeline in training.
|
||||
bool use_pipeline{false};
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -262,6 +265,7 @@ class TrainingSession : public InferenceSession {
|
|||
const std::vector<std::string>& norm_nodes,
|
||||
const bool dump_convergence_metrics);
|
||||
|
||||
common::Status InsertPipelineOps();
|
||||
common::Status ApplyTransformationsToMainGraph();
|
||||
|
||||
/** configure initial transformers for training */
|
||||
|
|
|
|||
|
|
@ -169,6 +169,7 @@ class TrainingRunner {
|
|||
common::Status ResetLossScaler();
|
||||
|
||||
size_t GetRound() const { return round_; }
|
||||
// Returns a mutable reference to the session_ member of this runner.
TrainingSession& GetSession() { return session_; }
|
||||
|
||||
private:
|
||||
Status TrainingLoop(IDataLoader& training_data_loader, IDataLoader* test_data_loader);
|
||||
|
|
|
|||
|
|
@ -510,6 +510,7 @@ def create_ort_training_session_bind_parameters(model, device, world_rank=-1, wo
|
|||
dtype_torch_to_numpy(torch_params[param].dtype), list(torch_tensor.size()),
|
||||
torch_tensor.data_ptr())
|
||||
|
||||
device_index = get_device_index(device)
|
||||
create_and_bind_grad_or_grad_accumulate_buffer(train_io_binding, torch_tensor, param, enable_grad_accumulation, device, device_index)
|
||||
|
||||
return session, train_io_binding, eval_io_binding, output_name, torch_params, output_types
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::ComputeTheoreticalJacobianTransp
|
|||
const size_t dy_size = y_infos[y_idx].shape.Size();
|
||||
|
||||
// Compute the theoretical Jacobians one row at a time by back propagating
|
||||
// '1.0'for each element of 'dy', while holding all other elements of 'dy' at zero.
|
||||
// '1.0' for each element of 'dy', while holding all other elements of 'dy' at zero.
|
||||
for (int c = 0; c < dy_size; ++c) { // for each value in the dy input vector
|
||||
// clear OpTester input/output/initializer
|
||||
op_session.ClearData();
|
||||
|
|
|
|||
|
|
@ -1599,6 +1599,57 @@ TEST(GradientCheckerTest, GatherNDGrad_int32_indice_unique_float_data_axis_2) {
|
|||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
// Checks the GatherElements gradient along axis 0 when several indices select
// the same data element (index 0 and index 2 each appear more than once per
// column), so gradient contributions have to be accumulated.
TEST(GradientCheckerTest, GatherElementsGradWithDuplicateUpdate) {
  // Initialized so that a failed run can never leave it indeterminate.
  float max_error = 0.0f;
  GradientChecker<float, float, float> gradient_checker;
  OpDef op_def{"GatherElements", kOnnxDomain, 11};

  TensorInfo data_info({3, 3}, true);
  TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
  // First vector: data values; second vector: index values for the int64 indices input.
  std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 2, 0, 2, 0, 0}};

  TensorInfo y_info({2, 3}, true);
  int64_t axis = 0;

  // Fail the test outright if the checker itself reports an error.
  ASSERT_TRUE(gradient_checker
                  .ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas,
                                        {MakeAttribute("axis", axis)})
                  .IsOK());
  EXPECT_IS_TINY(max_error);
}
|
||||
|
||||
// Checks the GatherElements gradient along axis 0 when no data element is
// selected twice: the first index row is all 1s and the second is all 2s.
TEST(GradientCheckerTest, GatherElementsGradWithoutDuplicateUpdate) {
  // Initialized so that a failed run can never leave it indeterminate.
  float max_error = 0.0f;
  GradientChecker<float, float, float> gradient_checker;
  OpDef op_def{"GatherElements", kOnnxDomain, 11};

  TensorInfo data_info({3, 3}, true);
  TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
  // First vector: data values; second vector: index values for the int64 indices input.
  std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 1, 1, 2, 2, 2}};

  TensorInfo y_info({2, 3}, true);
  int64_t axis = 0;

  // Fail the test outright if the checker itself reports an error.
  ASSERT_TRUE(gradient_checker
                  .ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas,
                                        {MakeAttribute("axis", axis)})
                  .IsOK());
  EXPECT_IS_TINY(max_error);
}
|
||||
|
||||
// Checks the GatherElements gradient along axis 1 when every index is 1, so
// each row gathers the same element several times.
TEST(GradientCheckerTest, GatherElementsGradAxisWithDuplicateUpdate) {
  // Initialized so that a failed run can never leave it indeterminate.
  float max_error = 0.0f;
  GradientChecker<float, float, float> gradient_checker;
  OpDef op_def{"GatherElements", kOnnxDomain, 11};

  TensorInfo data_info({3, 3}, true);
  TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
  // First vector: data values; second vector: index values for the int64 indices input.
  std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 1, 1, 1, 1, 1}};

  TensorInfo y_info({2, 3}, true);
  int64_t axis = 1;

  // Fail the test outright if the checker itself reports an error.
  ASSERT_TRUE(gradient_checker
                  .ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas,
                                        {MakeAttribute("axis", axis)})
                  .IsOK());
  EXPECT_IS_TINY(max_error);
}
|
||||
|
||||
TEST(GradientCheckerTest, LayerNormGrad) {
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
{
|
||||
|
|
@ -1800,6 +1851,84 @@ TEST(Synchronization, WaitAndRecordEventMany) {
|
|||
}
|
||||
}
|
||||
|
||||
// Numerically checks the Expand gradient for several input-shape / target-shape
// combinations. In every case only the data input is differentiated; the target
// shape is passed as a non-differentiable int64 input.
TEST(GradientCheckerTest, ExpandGrad) {
  const OpDef expand_def{"Expand"};
  GradientChecker<float, float, float> checker;
  float max_error;
  const auto int64_tensor_type = DataTypeImpl::GetTensorType<int64_t>();

  // input (2, 3, 1) expanded with target (2, 3, 4) -> result (2, 3, 4)
  {
    const TensorInfo input_info({2, 3, 1}, true);
    const TensorInfo target_info({3}, false, nullptr, int64_tensor_type);
    const TensorInfo output_info({2, 3, 4}, true);
    const std::vector<std::vector<float>> input_values = {{1, 2, 3, 4, 5, 6}, {2, 3, 4}};

    checker.ComputeGradientError(expand_def, {input_info, target_info}, {output_info}, &max_error, input_values);
    EXPECT_IS_TINY(max_error);
  }

  // input (2, 3, 1) expanded with target (1, 1, 4) -> result (2, 3, 4)
  {
    const TensorInfo input_info({2, 3, 1}, true);
    const TensorInfo target_info({3}, false, nullptr, int64_tensor_type);
    const TensorInfo output_info({2, 3, 4}, true);
    const std::vector<std::vector<float>> input_values = {{1, 2, 3, 4, 5, 6}, {1, 1, 4}};

    checker.ComputeGradientError(expand_def, {input_info, target_info}, {output_info}, &max_error, input_values);
    EXPECT_IS_TINY(max_error);
  }

  // input (2, 3, 1) expanded with target (4) -> result (2, 3, 4)
  {
    const TensorInfo input_info({2, 3, 1}, true);
    const TensorInfo target_info({1}, false, nullptr, int64_tensor_type);
    const TensorInfo output_info({2, 3, 4}, true);
    const std::vector<std::vector<float>> input_values = {{1, 2, 3, 4, 5, 6}, {4}};

    checker.ComputeGradientError(expand_def, {input_info, target_info}, {output_info}, &max_error, input_values);
    EXPECT_IS_TINY(max_error);
  }

  // input (2, 3, 1) expanded with target (1, 1) -> result (2, 3, 1)
  {
    const TensorInfo input_info({2, 3, 1}, true);
    const TensorInfo target_info({2}, false, nullptr, int64_tensor_type);
    const TensorInfo output_info({2, 3, 1}, true);
    const std::vector<std::vector<float>> input_values = {{1, 2, 3, 4, 5, 6}, {1, 1}};

    checker.ComputeGradientError(expand_def, {input_info, target_info}, {output_info}, &max_error, input_values);
    EXPECT_IS_TINY(max_error);
  }

  // input (2, 3) expanded with target (4, 5, 2, 3) -> result (4, 5, 2, 3)
  {
    const TensorInfo input_info({2, 3}, true);
    const TensorInfo target_info({4}, false, nullptr, int64_tensor_type);
    const TensorInfo output_info({4, 5, 2, 3}, true);
    const std::vector<std::vector<float>> input_values = {{1, 2, 3, 4, 5, 6}, {4, 5, 2, 3}};

    checker.ComputeGradientError(expand_def, {input_info, target_info}, {output_info}, &max_error, input_values);
    EXPECT_IS_TINY(max_error);
  }

  // input (1, 2, 3) expanded with target (4, 5, 1, 1) -> result (4, 5, 2, 3)
  {
    const TensorInfo input_info({1, 2, 3}, true);
    const TensorInfo target_info({4}, false, nullptr, int64_tensor_type);
    const TensorInfo output_info({4, 5, 2, 3}, true);
    const std::vector<std::vector<float>> input_values = {{1, 2, 3, 4, 5, 6}, {4, 5, 1, 1}};

    checker.ComputeGradientError(expand_def, {input_info, target_info}, {output_info}, &max_error, input_values);
    EXPECT_IS_TINY(max_error);
  }
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
|
|
|||
|
|
@ -997,6 +997,106 @@ class PipelineBatchPlanner {
|
|||
}
|
||||
};
|
||||
|
||||
// verify pipeline config can load and gradient graph can construct.
|
||||
TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) {
|
||||
PathString filename_base = ORT_TSTR("testdata/test_training_model_");
|
||||
|
||||
auto load_gradient_graph = [](int stageIdx, PathString& input_file, PathString& output_file) {
|
||||
auto config = MakeBasicTrainingConfig();
|
||||
|
||||
config.use_pipeline = true;
|
||||
|
||||
PathString backprop_model_file;
|
||||
ASSERT_STATUS_OK(BuildBackPropGraph(input_file, config, backprop_model_file));
|
||||
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_TRUE(Model::Load(backprop_model_file, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
|
||||
|
||||
Graph& graph = model->MainGraph();
|
||||
auto is_backward = [](Node& node) {
|
||||
return (node.Description() == "Backward pass");
|
||||
};
|
||||
// check for wait/record node
|
||||
Node* wait_fw{nullptr};
|
||||
Node* wait_bw{nullptr};
|
||||
Node* record_fw{nullptr};
|
||||
Node* record_bw{nullptr};
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "WaitEvent") {
|
||||
if (is_backward(node)) {
|
||||
wait_bw = &node;
|
||||
} else {
|
||||
wait_fw = &node;
|
||||
}
|
||||
} else if (node.OpType() == "RecordEvent") {
|
||||
if (is_backward(node)) {
|
||||
record_bw = &node;
|
||||
} else {
|
||||
record_fw = &node;
|
||||
}
|
||||
}
|
||||
}
|
||||
// every partition should have wait forward and record backward
|
||||
ASSERT_TRUE(wait_fw && record_bw);
|
||||
if (stageIdx == 2) {
|
||||
// the last partition can perform back prop right away. It won't have record
|
||||
// forward and wait backward
|
||||
ASSERT_TRUE(!record_fw && !wait_bw);
|
||||
} else {
|
||||
ASSERT_TRUE(record_fw && wait_bw);
|
||||
}
|
||||
|
||||
// check for send/recv node
|
||||
Node* send_fw{nullptr};
|
||||
Node* send_bw{nullptr};
|
||||
Node* recv_fw{nullptr};
|
||||
Node* recv_bw{nullptr};
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Send") {
|
||||
if (is_backward(node)) {
|
||||
send_bw = &node;
|
||||
} else {
|
||||
send_fw = &node;
|
||||
}
|
||||
} else if (node.OpType() == "Recv") {
|
||||
if (is_backward(node)) {
|
||||
recv_bw = &node;
|
||||
} else {
|
||||
recv_fw = &node;
|
||||
}
|
||||
}
|
||||
}
|
||||
// except the last partion, each partition should have send forward and recv backward
|
||||
if (stageIdx == 0 || stageIdx == 1) {
|
||||
ASSERT_TRUE(send_fw && recv_bw);
|
||||
} else {
|
||||
ASSERT_TRUE(!send_fw && !recv_bw);
|
||||
}
|
||||
// except the first partion, each partition should have recv forward and send backward
|
||||
if (stageIdx == 1 || stageIdx == 2) {
|
||||
ASSERT_TRUE(recv_fw && send_bw);
|
||||
} else {
|
||||
ASSERT_TRUE(!recv_fw && !send_bw);
|
||||
}
|
||||
|
||||
auto mp = model->ToProto();
|
||||
std::ofstream ofs(output_file, std::ofstream::binary);
|
||||
mp.SerializeToOstream(&ofs);
|
||||
ofs.close();
|
||||
};
|
||||
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
#ifdef _WIN32
|
||||
auto surfix = std::to_wstring(i);
|
||||
#else
|
||||
auto surfix = std::to_string(i);
|
||||
#endif
|
||||
PathString input_file = filename_base + surfix + ORT_TSTR(".onnx");
|
||||
PathString output_file = filename_base + surfix + ORT_TSTR("_back.onnx");
|
||||
load_gradient_graph(i, input_file, output_file);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GradientGraphBuilderTest, TrainingSession_WithPipeline) {
|
||||
auto config = MakeBasicTrainingConfig();
|
||||
//config.set_gradients_as_graph_outputs = true;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,227 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
namespace test {
|
||||
|
||||
// No axis attribute is set: each dY element is expected in the dX row named by
// the matching index, in the same column.
TEST(GatherElementsGrad, WithoutAxis) {
  onnxruntime::test::OpTester tester("GatherElementsGrad", 1, kMSDomain);

  const std::vector<int64_t> grad_dims = {2, 3};
  const std::vector<int64_t> data_shape = {3, 3};

  tester.AddInput<float>("dY", grad_dims,
                         {1.0f, 1.1f, 1.2f,
                          2.0f, 2.1f, 2.2f});
  tester.AddInput<int64_t>("data_shape", {2}, data_shape);
  tester.AddInput<int64_t>("indices", grad_dims,
                           {1, 0, 2,
                            0, 2, 1});
  tester.AddOutput<float>("dX", data_shape,
                          {2.0f, 1.1f, 0.0f,
                           1.0f, 0.0f, 2.2f,
                           0.0f, 2.1f, 1.2f});
  tester.Run();
}
|
||||
|
||||
// axis = 1: the two dY values are expected at columns 1 and 3 of a 1x5 dX.
TEST(GatherElementsGrad, WithAxis) {
  onnxruntime::test::OpTester tester("GatherElementsGrad", 1, kMSDomain);
  tester.AddAttribute<int64_t>("axis", 1);

  const std::vector<int64_t> grad_dims = {1, 2};
  const std::vector<int64_t> data_shape = {1, 5};

  tester.AddInput<float>("dY", grad_dims, {1.1f, 2.1f});
  tester.AddInput<int64_t>("data_shape", {2}, data_shape);
  tester.AddInput<int64_t>("indices", grad_dims, {1, 3});
  tester.AddOutput<float>("dX", data_shape, {0.0f, 1.1f, 0.0f, 2.1f, 0.0f});
  tester.Run();
}
|
||||
|
||||
// Rank-3 case with axis = 0, where dX is expected to equal dY.
TEST(GatherElementsGrad, ThreeDimsWithAxis_0) {
  onnxruntime::test::OpTester tester("GatherElementsGrad", 1, kMSDomain);
  tester.AddAttribute<int64_t>("axis", 0);

  const std::vector<int64_t> data_shape = {1, 3, 3};

  tester.AddInput<float>("dY", data_shape,
                         {11.0f, 12.0f, 13.0f,
                          14.0f, 15.0f, 16.0f,
                          17.0f, 18.0f, 19.0f});
  tester.AddInput<int64_t>("data_shape", {3}, data_shape);

  // Axis 0 has extent 1, so 0 is the only valid index value.
  tester.AddInput<int64_t>("indices", data_shape,
                           {0, 0, 0,
                            0, 0, 0,
                            0, 0, 0});

  tester.AddOutput<float>("dX", data_shape,
                          {11.0f, 12.0f, 13.0f,
                           14.0f, 15.0f, 16.0f,
                           17.0f, 18.0f, 19.0f});
  tester.Run();
}
|
||||
|
||||
// Rank-3 case with axis = 2: indices {2, 1, 0} in every row, so each row of dX
// is expected to be the matching dY row reversed.
TEST(GatherElementsGrad, ThreeDimsWithAxis_2) {
  onnxruntime::test::OpTester tester("GatherElementsGrad", 1, kMSDomain);
  tester.AddAttribute<int64_t>("axis", 2);

  const std::vector<int64_t> data_shape = {1, 3, 3};

  tester.AddInput<float>("dY", data_shape,
                         {11, 12, 13,
                          14, 15, 16,
                          17, 18, 19});
  tester.AddInput<int64_t>("data_shape", {3}, data_shape);
  tester.AddInput<int64_t>("indices", data_shape,
                           {2, 1, 0,
                            2, 1, 0,
                            2, 1, 0});
  tester.AddOutput<float>("dX", data_shape,
                          {13, 12, 11,
                           16, 15, 14,
                           19, 18, 17});
  tester.Run();
}
|
||||
|
||||
// axis = -1 on a rank-2 tensor; same inputs and expected dX as the axis = 1 test.
TEST(GatherElementsGrad, NegativeAxis) {
  onnxruntime::test::OpTester tester("GatherElementsGrad", 1, kMSDomain);
  tester.AddAttribute<int64_t>("axis", -1);

  const std::vector<int64_t> grad_dims = {1, 2};
  const std::vector<int64_t> data_shape = {1, 5};

  tester.AddInput<float>("dY", grad_dims, {1.1f, 2.1f});
  tester.AddInput<int64_t>("data_shape", {2}, data_shape);
  tester.AddInput<int64_t>("indices", grad_dims, {1, 3});
  tester.AddOutput<float>("dX", data_shape, {0.0f, 1.1f, 0.0f, 2.1f, 0.0f});
  tester.Run();
}
|
||||
|
||||
// indices is 1x3 while dY is 1x2, so the run is expected to fail with the
// dimension-mismatch message; the dX values given are placeholders.
TEST(GatherElementsGrad, IndicesUpdatesDontMatch) {
  onnxruntime::test::OpTester tester("GatherElementsGrad", 1, kMSDomain);
  tester.AddAttribute<int64_t>("axis", 1);

  const std::vector<int64_t> data_shape = {1, 5};

  tester.AddInput<float>("dY", {1, 2}, {1.1f, 2.1f});
  tester.AddInput<int64_t>("data_shape", {2}, data_shape);
  tester.AddInput<int64_t>("indices", {1, 3}, {1, 3, 3});
  tester.AddOutput<float>("dX", data_shape, {1.0f, 3.1f, 3.0f, 6.1f, 5.0f});

  tester.Run(onnxruntime::test::OpTester::ExpectResult::kExpectFailure,
             "Indices vs dY dimensions differs at position=1 3 vs 2");
}
|
||||
|
||||
// Single gradient element with index 3 along axis 0 of a 4x2x1 dX: the value is
// expected at flat offset 6.
TEST(GatherElementsGrad, ValidAxis) {
  onnxruntime::test::OpTester tester("GatherElementsGrad", 1, kMSDomain);
  tester.AddAttribute<int64_t>("axis", 0);

  const std::vector<int64_t> grad_dims = {1, 1, 1};
  const std::vector<int64_t> data_shape = {4, 2, 1};

  tester.AddInput<float>("dY", grad_dims, {5.0f});
  tester.AddInput<int64_t>("data_shape", {3}, data_shape);
  tester.AddInput<int64_t>("indices", grad_dims, {3});
  tester.AddOutput<float>("dX", data_shape, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 5.0f, 0.0f});
  tester.Run();
}
|
||||
|
||||
// Index -1 along axis 0 of a 4x2x1 dX; the expected output is the same as for
// index 3 in the ValidAxis test.
TEST(GatherElementsGrad, ValidNegativeIndex) {
  onnxruntime::test::OpTester tester("GatherElementsGrad", 1, kMSDomain);
  tester.AddAttribute<int64_t>("axis", 0);

  const std::vector<int64_t> grad_dims = {1, 1, 1};
  const std::vector<int64_t> data_shape = {4, 2, 1};

  tester.AddInput<float>("dY", grad_dims, {5.0f});
  tester.AddInput<int64_t>("data_shape", {3}, data_shape);
  tester.AddInput<int64_t>("indices", grad_dims, {-1});
  tester.AddOutput<float>("dX", data_shape, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 5.0f, 0.0f});
  tester.Run();
}
|
||||
|
||||
// No axis attribute, int32 indices all equal to 1: both dY rows are expected to
// be summed into row 1 of dX, column by column.
TEST(GatherElementsGrad, SameUpdateWithoutAxis) {
  onnxruntime::test::OpTester tester("GatherElementsGrad", 1, kMSDomain);

  const std::vector<int64_t> grad_dims = {2, 2};
  const std::vector<int64_t> data_shape = {3, 3};

  tester.AddInput<float>("dY", grad_dims,
                         {11.0f, 22.0f,
                          33.0f, 44.0f});
  tester.AddInput<int64_t>("data_shape", {2}, data_shape);

  // The trailing `true` is the fourth AddInput argument, as in the other int32-index tests.
  tester.AddInput<int32_t>("indices", grad_dims,
                           {1, 1,
                            1, 1},
                           true);

  tester.AddOutput<float>("dX", data_shape,
                          {0.0f, 0.0f, 0.0f,
                           44.0f, 66.0f, 0.0f,
                           0.0f, 0.0f, 0.0f});
  tester.Run();
}
|
||||
|
||||
// axis = 1, int32 indices all equal to 1: each dY row is expected to be summed
// into column 1 of the same dX row.
TEST(GatherElementsGrad, SameUpdateWithAxis) {
  onnxruntime::test::OpTester tester("GatherElementsGrad", 1, kMSDomain);
  tester.AddAttribute<int64_t>("axis", 1);

  const std::vector<int64_t> grad_dims = {2, 3};
  const std::vector<int64_t> data_shape = {3, 3};

  tester.AddInput<float>("dY", grad_dims,
                         {11.0f, 22.0f, 33.0f,
                          44.0f, 55.0f, 66.0f});
  tester.AddInput<int64_t>("data_shape", {2}, data_shape);
  tester.AddInput<int32_t>("indices", grad_dims,
                           {1, 1, 1,
                            1, 1, 1},
                           true);

  tester.AddOutput<float>("dX", data_shape,
                          {0.0f, 66.0f, 0.0f,
                           0.0f, 165.0f, 0.0f,
                           0.0f, 0.0f, 0.0f});
  tester.Run();
}
|
||||
|
||||
// axis = -1 with int32 indices {1, 0, 1} per row: the first and third dY values
// of each row are expected to be summed into column 1, the second into column 0.
TEST(GatherElementsGrad, SameUpdateWithNegativeAxis) {
  onnxruntime::test::OpTester tester("GatherElementsGrad", 1, kMSDomain);
  tester.AddAttribute<int64_t>("axis", -1);

  const std::vector<int64_t> grad_dims = {2, 3};
  const std::vector<int64_t> data_shape = {3, 3};

  tester.AddInput<float>("dY", grad_dims,
                         {11.0f, 22.0f, 33.0f,
                          44.0f, 55.0f, 66.0f});
  tester.AddInput<int64_t>("data_shape", {2}, data_shape);
  tester.AddInput<int32_t>("indices", grad_dims,
                           {1, 0, 1,
                            1, 0, 1},
                           true);

  tester.AddOutput<float>("dX", data_shape,
                          {22.0f, 44.0f, 0.0f,
                           55.0f, 110.0f, 0.0f,
                           0.0f, 0.0f, 0.0f});
  tester.Run();
}
|
||||
|
||||
// Half-precision variant of SameUpdateWithoutAxis: the same values, converted
// from float to MLFloat16 for both dY and the expected dX.
TEST(GatherElementsGrad, SameUpdateWithoutAxisMLFloat16) {
  onnxruntime::test::OpTester tester("GatherElementsGrad", 1, kMSDomain);

  // Upstream gradient, written in float and then converted.
  std::vector<float> grad_values = {11.0f, 22.0f,
                                    33.0f, 44.0f};
  std::vector<MLFloat16> grad_values_fp16(grad_values.size());
  onnxruntime::test::ConvertFloatToMLFloat16(grad_values.data(), grad_values_fp16.data(),
                                             static_cast<int>(grad_values.size()));

  // Expected result, written in float and then converted.
  std::vector<float> expected_values = {0.0f, 0.0f, 0.0f,
                                        44.0f, 66.0f, 0.0f,
                                        0.0f, 0.0f, 0.0f};
  std::vector<MLFloat16> expected_values_fp16(expected_values.size());
  onnxruntime::test::ConvertFloatToMLFloat16(expected_values.data(), expected_values_fp16.data(),
                                             static_cast<int>(expected_values.size()));

  const std::vector<int64_t> grad_dims = {2, 2};
  const std::vector<int64_t> data_shape = {3, 3};

  tester.AddInput<MLFloat16>("dY", grad_dims, grad_values_fp16);
  tester.AddInput<int64_t>("data_shape", {2}, data_shape);
  tester.AddInput<int32_t>("indices", grad_dims,
                           {1, 1,
                            1, 1},
                           true);
  tester.AddOutput<MLFloat16>("dX", data_shape, expected_values_fp16);

  tester.Run();
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -6,7 +6,6 @@
|
|||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/multi_tensor/common.cuh"
|
||||
|
||||
constexpr int PARALLEL_LOADS = 4;
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
|
|
|||
|
|
@ -190,8 +190,8 @@ __device__ void cuWelfordMuSigma2(
|
|||
for (; l + 7 < n2; l += 8 * numx) {
|
||||
for (int k = 0; k < 8; k += 2) {
|
||||
float2 curr = __half22float2(*((__half2*)(lvals + l + k)));
|
||||
cuWelfordOnlineSum(curr.x, mu, sigma2, count);
|
||||
cuWelfordOnlineSum(curr.y, mu, sigma2, count);
|
||||
cuWelfordOnlineSum(static_cast<float>(curr.x), mu, sigma2, count);
|
||||
cuWelfordOnlineSum(static_cast<float>(curr.y), mu, sigma2, count);
|
||||
}
|
||||
}
|
||||
for (; l < n2; ++l) {
|
||||
|
|
@ -308,7 +308,7 @@ __global__ void cuApplyLayerNorm(
|
|||
// 1) blockDim.x == GPU_WARP_SIZE
|
||||
// 2) Tensors are contiguous
|
||||
//
|
||||
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
|
||||
for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
|
||||
SharedMemory<U> shared;
|
||||
U* buf = shared.getPointer();
|
||||
U mu, sigma2;
|
||||
|
|
@ -576,7 +576,7 @@ __global__ void cuComputeGradInput(
|
|||
const U* __restrict__ invvar,
|
||||
const T* gamma,
|
||||
T* grad_input) {
|
||||
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
|
||||
for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
|
||||
U sum_loss1 = U(0);
|
||||
U sum_loss2 = U(0);
|
||||
const U c_mean = mean[i1];
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
#include "core/providers/cuda/multi_tensor/common.cuh"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,137 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/training_ops/cuda/tensor/gather_elements_grad.h"
|
||||
#include "orttraining/training_ops/cuda/tensor/gather_elements_grad_impl.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "core/providers/common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
// Registers the GatherElementsGrad kernel (kMSDomain, version 1) with the CUDA
// execution provider. Constraint "T" accepts all IEEE float tensor types, "I" is
// int64 and "Tind" is int32 or int64.
ONNX_OPERATOR_KERNEL_EX(
|
||||
GatherElementsGrad,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(1) // input 1 (the shape of the original 'GatherElements' data) needs to be on CPU
|
||||
.TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes())
|
||||
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>())
|
||||
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
GatherElementsGrad);
|
||||
|
||||
template <typename T>
|
||||
struct GatherElementsGrad::ComputeImpl {
|
||||
Status operator()(const Tensor* dY,
|
||||
const Tensor* indices_tensor,
|
||||
Tensor* dX,
|
||||
const int rank,
|
||||
TArray<int64_t>& buffer_output_dims,
|
||||
TArray<int64_t>& buffer_input_strides,
|
||||
const int64_t indices_size,
|
||||
TArray<int64_t>& buffer_indices_dims,
|
||||
TArray<fast_divmod>& fdm_indices_strides,
|
||||
const int axis) const {
|
||||
T* output_data = dX->template MutableData<T>();
|
||||
const T* update_data = dY->template Data<T>();
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
|
||||
MLDataType Tin_type = indices_tensor->DataType();
|
||||
if (utils::IsPrimitiveDataType<int32_t>(Tin_type)) {
|
||||
const int32_t* indices_data = indices_tensor->template Data<int32_t>();
|
||||
return GatherElementsGradImpl(
|
||||
rank,
|
||||
buffer_output_dims,
|
||||
buffer_input_strides,
|
||||
indices_data,
|
||||
indices_size,
|
||||
buffer_indices_dims,
|
||||
fdm_indices_strides,
|
||||
reinterpret_cast<const CudaT*>(update_data),
|
||||
axis,
|
||||
reinterpret_cast<CudaT*>(output_data));
|
||||
} else if (utils::IsPrimitiveDataType<int64_t>(Tin_type)) {
|
||||
const int64_t* indices_data = indices_tensor->template Data<int64_t>();
|
||||
return GatherElementsGradImpl(
|
||||
rank,
|
||||
buffer_output_dims,
|
||||
buffer_input_strides,
|
||||
indices_data,
|
||||
indices_size,
|
||||
buffer_indices_dims,
|
||||
fdm_indices_strides,
|
||||
reinterpret_cast<const CudaT*>(update_data),
|
||||
axis,
|
||||
reinterpret_cast<CudaT*>(output_data));
|
||||
}
|
||||
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for Tin is not supported yet in ScatterElements.");
|
||||
}
|
||||
};
|
||||
|
||||
// Computes dX for GatherElements.
//
// Inputs: 0 = dY, 1 = shape of the original data (int64 values), 2 = indices.
// Output 0 (dX) takes the shape given by input 1 and is zero-filled before the
// type-dispatched implementation is invoked.
//
// Returns INVALID_ARGUMENT when indices and dY differ in rank or in any
// dimension, when indices and the output differ in rank, or when an indices
// dimension is larger than the matching output dimension.
Status GatherElementsGrad::ComputeInternal(OpKernelContext* context) const {
  const Tensor* dY = context->Input<Tensor>(0);
  const Tensor* shape_tensor = context->Input<Tensor>(1);
  const TensorShape data_shape(shape_tensor->template Data<int64_t>(), shape_tensor->Shape().Size());

  const int axis = static_cast<int>(HandleNegativeAxis(axis_, data_shape.NumDimensions()));

  const Tensor* indices_tensor = context->Input<Tensor>(2);
  const auto& indices_dims = indices_tensor->Shape().GetDims();
  const auto& dY_dims = dY->Shape().GetDims();
  const int64_t indices_size = indices_tensor->Shape().Size();

  // indices and dY must agree in rank and in every dimension.
  if (indices_dims.size() != dY_dims.size()) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Indices and dY must have the same rank");
  }
  for (size_t pos = 0; pos < indices_dims.size(); ++pos) {
    if (indices_dims[pos] != dY_dims[pos]) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices vs dY dimensions differs at position=", pos,
                             " ", indices_dims[pos], " vs ", dY_dims[pos]);
    }
  }

  // According to the spec, the rank of ind/upd shall be the same as that of the
  // output (data); we also make sure that the dimensions of ind/upd do not
  // exceed those of the output.
  const auto& output_dims = data_shape.GetDims();
  if (output_dims.size() != indices_dims.size()) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices must have the same rank as Output. Indices rank=",
                           indices_dims.size(), ". Output rank=", output_dims.size());
  }
  for (size_t pos = 0; pos < output_dims.size(); ++pos) {
    if (output_dims[pos] < indices_dims[pos]) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices dim=", indices_dims[pos], " at pos=", pos,
                             " is greater than Output dim=", output_dims[pos]);
    }
  }

  const int rank = static_cast<int>(output_dims.size());

  // Zero-fill dX before the implementation writes into it.
  Tensor* dX = context->Output(0, data_shape);
  CUDA_RETURN_IF_ERROR(cudaMemset(dX->MutableDataRaw(), 0, dX->SizeInBytes()));

  // Dimension and stride metadata handed to the implementation.
  TArray<int64_t> buffer_output_dims(output_dims);
  TensorPitches input_strides(output_dims);
  TArray<int64_t> buffer_input_strides(input_strides);

  TArray<int64_t> buffer_indices_dims(indices_dims);
  TensorPitches indices_strides(indices_dims);
  TArray<fast_divmod> fdm_indices_strides(rank);
  for (int dim = 0; dim < rank; ++dim) {
    fdm_indices_strides[dim] = fast_divmod(static_cast<int>(indices_strides[dim]));
  }

  // Dispatch on the element type of dY.
  utils::MLTypeCallDispatcherRet<Status, ComputeImpl, MLFloat16, float, double>
      t_disp(dY->GetElementType());
  return t_disp.Invoke(dY, indices_tensor, dX, rank,
                       buffer_output_dims, buffer_input_strides, indices_size,
                       buffer_indices_dims, fdm_indices_strides, axis);
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
class GatherElementsGrad final : public CudaKernel {
|
||||
public:
|
||||
GatherElementsGrad(const OpKernelInfo& info) : CudaKernel(info) {
|
||||
info.GetAttrOrDefault("axis", &axis_, static_cast<int64_t>(0));
|
||||
}
|
||||
~GatherElementsGrad() = default;
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
struct ComputeImpl;
|
||||
|
||||
int64_t axis_;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
26
orttraining/orttraining/training_ops/cuda/tensor/gather_elements_grad_impl.h
Executable file
26
orttraining/orttraining/training_ops/cuda/tensor/gather_elements_grad_impl.h
Executable file
|
|
@ -0,0 +1,26 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template <typename T, typename Tin>
|
||||
Status GatherElementsGradImpl(
|
||||
const int rank,
|
||||
TArray<int64_t>& buffer_input_dims,
|
||||
TArray<int64_t>& buffer_input_strides,
|
||||
const Tin* indices_data,
|
||||
const int64_t indices_size,
|
||||
TArray<int64_t>& buffer_indices_dims,
|
||||
TArray<fast_divmod>& indices_strides,
|
||||
const T* updates,
|
||||
const int axis,
|
||||
T* output_data);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -121,6 +121,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_float, LayerNormalizationGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float, LayerNormalizationGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SliceGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherElementsGrad);
|
||||
|
||||
#ifdef USE_HOROVOD
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, HorovodAllReduce);
|
||||
|
|
@ -252,6 +253,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_float, LayerNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float, LayerNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SliceGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherElementsGrad)>,
|
||||
|
||||
#ifdef USE_HOROVOD
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, HorovodAllReduce)>,
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ import torch.nn.functional as F
|
|||
import torch.optim as optim
|
||||
from torchvision import datasets, transforms
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
# TODO: remove after ready for CV
|
||||
# import sys
|
||||
|
|
@ -27,13 +28,15 @@ import numpy as np
|
|||
# from ort_trainer import IODescription, ModelDescription, ORTTrainer, ORTModel
|
||||
|
||||
from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer, ORTModel
|
||||
from mpi4py import MPI
|
||||
from onnxruntime.capi._pybind_state import set_cuda_device_id
|
||||
|
||||
class NeuralNet(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
super(NeuralNet, self).__init__()
|
||||
self.fc1 = nn.Linear(input_size, hidden_size)
|
||||
self.fc1 = nn.Linear(input_size, hidden_size)
|
||||
self.relu = nn.ReLU()
|
||||
self.fc2 = nn.Linear(hidden_size, num_classes)
|
||||
self.fc2 = nn.Linear(hidden_size, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.fc1(x)
|
||||
|
|
@ -79,18 +82,19 @@ def test_with_model(args, model, device, test_loader, optimizer, epoch):
|
|||
test_loss, correct, len(test_loader.dataset),
|
||||
100. * correct / len(test_loader.dataset)))
|
||||
|
||||
def train_with_trainer(args, trainer, device, train_loader, epoch):
|
||||
def train_with_trainer(args, trainer, device, train_loader, epoch):
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
data, target = data.to(device), target.to(device)
|
||||
data = data.reshape(data.shape[0], -1)
|
||||
|
||||
learning_rate = torch.tensor([args.lr])
|
||||
loss = trainer.train_step((data, target, learning_rate))
|
||||
loss = trainer.train_step(data, target, learning_rate)
|
||||
|
||||
# Since the output corresponds to [loss_desc, probability_desc], the first value is taken as loss.
|
||||
if batch_idx % args.log_interval == 0:
|
||||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
||||
epoch, batch_idx * len(data), len(train_loader.dataset),
|
||||
100. * batch_idx / len(train_loader), loss.item()))
|
||||
100. * batch_idx / len(train_loader), loss[0]))
|
||||
|
||||
# TODO: complete this once ORT training can do evaluation.
|
||||
def test_with_trainer(args, trainer, device, test_loader):
|
||||
|
|
@ -152,8 +156,6 @@ def main():
|
|||
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
device = torch.device("cuda" if use_cuda else "cpu")
|
||||
|
||||
kwargs = {'num_workers': 0, 'pin_memory': True}
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST('../data', train=True, download=True,
|
||||
|
|
@ -167,6 +169,19 @@ def main():
|
|||
transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])),
|
||||
batch_size=args.test_batch_size, shuffle=True, **kwargs)
|
||||
|
||||
|
||||
comm = MPI.COMM_WORLD
|
||||
args.local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) if ('OMPI_COMM_WORLD_LOCAL_RANK' in os.environ) else 0
|
||||
args.world_rank = int(os.environ['OMPI_COMM_WORLD_RANK']) if ('OMPI_COMM_WORLD_RANK' in os.environ) else 0
|
||||
args.world_size=comm.Get_size()
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
if use_cuda:
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
args.n_gpu = 1
|
||||
set_cuda_device_id(args.local_rank)
|
||||
|
||||
input_size = 784
|
||||
hidden_size = 500
|
||||
num_classes = 10
|
||||
|
|
@ -175,14 +190,17 @@ def main():
|
|||
model_desc = mnist_model_description()
|
||||
if args.use_ort_trainer:
|
||||
# use log_interval as gradient accumulate steps
|
||||
trainer = ORTTrainer(model, my_loss, model_desc, "SGDOptimizer", None, IODescription('Learning_Rate', [1,], torch.float32), device)
|
||||
trainer = ORTTrainer(model, my_loss, model_desc, "LambOptimizer", None, IODescription('Learning_Rate', [1,], torch.float32), device, 1, None,
|
||||
args.world_rank, args.world_size, use_mixed_precision=False, allreduce_post_accumulation = True)
|
||||
print('\nBuild ort model done.')
|
||||
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
train_with_trainer(args, trainer, device, train_loader, epoch)
|
||||
import pdb
|
||||
test_with_trainer(args, trainer, device, test_loader)
|
||||
else:
|
||||
model = ORTModel(model, my_loss, model_desc, device)
|
||||
model = ORTModel(model, my_loss, model_desc, device, None, args.world_rank, args.world_size)
|
||||
print('\nBuild ort model done.')
|
||||
|
||||
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import onnx
|
|||
from onnx import helper
|
||||
from onnx import TensorProto
|
||||
from onnx import OperatorSetIdProto
|
||||
|
||||
# Edge that needs to be cut for the split.
|
||||
# If the edge is feeding into more than one node, and not all the nodes belong to the same cut,
|
||||
# specify those consuming nodes that need to be cut
|
||||
|
|
@ -29,31 +28,6 @@ def split_graph(model, split_edge_groups):
|
|||
|
||||
new_send_nodes = []
|
||||
new_recv_nodes = []
|
||||
# Add wait for initial inputs. This needs to be done first before new inputs
|
||||
# are introduced from split
|
||||
initializer_lists = [a.name for a in model.graph.initializer]
|
||||
input_tensors = [
|
||||
value.name for value in model.graph.input if value.name not in initializer_lists]
|
||||
|
||||
input_wait_signal = model.graph.input.add()
|
||||
input_wait_signal.CopyFrom(helper.make_tensor_value_info(
|
||||
'input_wait_signal', onnx.TensorProto.INT64, None))
|
||||
|
||||
input_wait = model.graph.node.add()
|
||||
input_wait.CopyFrom(helper.make_node(
|
||||
'WaitEvent',
|
||||
inputs=['input_wait_signal'],
|
||||
outputs=[],
|
||||
domain=ms_domain))
|
||||
|
||||
for i in input_tensors:
|
||||
for node in model.graph.node:
|
||||
for j in range(len(node.input)):
|
||||
if node.input[j] == i:
|
||||
node.input[j] = i + '_sync'
|
||||
|
||||
input_wait.input.extend(input_tensors)
|
||||
input_wait.output.extend([i + '_sync' for i in input_tensors])
|
||||
|
||||
for cut_index in range(len(split_edge_groups)):
|
||||
edgeIds = split_edge_groups[cut_index]
|
||||
|
|
@ -62,7 +36,7 @@ def split_graph(model, split_edge_groups):
|
|||
upstream_nodes = []
|
||||
upstream_nodes_output_index = []
|
||||
output_shapes = []
|
||||
|
||||
element_types = []
|
||||
for id in edgeIds:
|
||||
for node in model.graph.node:
|
||||
if len(node.output) >= 1:
|
||||
|
|
@ -70,25 +44,43 @@ def split_graph(model, split_edge_groups):
|
|||
if j == id:
|
||||
upstream_nodes.append(node)
|
||||
upstream_nodes_output_index.append(i)
|
||||
for info in model.graph.value_info:
|
||||
if info.name == id:
|
||||
output_shapes.append(info.type)
|
||||
|
||||
record_signal = model.graph.input.add()
|
||||
record_signal.CopyFrom(helper.make_tensor_value_info(
|
||||
'record_input_signal' + str(cut_index), onnx.TensorProto.INT64, None))
|
||||
|
||||
wait_signal = model.graph.input.add()
|
||||
wait_signal.CopyFrom(helper.make_tensor_value_info(
|
||||
'wait_input_signal' + str(cut_index), onnx.TensorProto.INT64, None))
|
||||
# assuming all tensors are of type float
|
||||
element_types.append(1)
|
||||
for info in model.graph.value_info:
|
||||
if info.name == id:
|
||||
output_shapes.append(info.type)
|
||||
|
||||
send_input_signal_name = 'send_input_signal' + str(cut_index)
|
||||
send_signal = model.graph.input.add()
|
||||
send_signal.CopyFrom(helper.make_tensor_value_info(
|
||||
'send_input_signal' + str(cut_index), onnx.TensorProto.BOOL, None))
|
||||
send_input_signal_name, onnx.TensorProto.BOOL, None))
|
||||
send_signal = helper.make_tensor(
|
||||
send_input_signal_name, TensorProto.BOOL, (), (True,))
|
||||
model.graph.initializer.extend([send_signal])
|
||||
|
||||
recv_input_signal_name = 'recv_input_signal' + str(cut_index)
|
||||
recv_signal = model.graph.input.add()
|
||||
recv_signal.CopyFrom(helper.make_tensor_value_info(
|
||||
'recv_input_signal' + str(cut_index), onnx.TensorProto.BOOL, None))
|
||||
recv_input_signal_name, onnx.TensorProto.BOOL, None))
|
||||
recv_signal = helper.make_tensor(
|
||||
recv_input_signal_name, TensorProto.BOOL, (), (True,))
|
||||
model.graph.initializer.extend([recv_signal])
|
||||
|
||||
send_dst_rank_name = 'send_dst_rank' + str(cut_index)
|
||||
send_dst_rank = model.graph.input.add()
|
||||
send_dst_rank.CopyFrom(helper.make_tensor_value_info(
|
||||
send_dst_rank_name, onnx.TensorProto.INT64, None))
|
||||
send_dst_rank = helper.make_tensor(
|
||||
send_dst_rank_name, TensorProto.INT64, (), (cut_index + 1,))
|
||||
model.graph.initializer.extend([send_dst_rank])
|
||||
|
||||
recv_src_rank_name = 'recv_src_rank' + str(cut_index)
|
||||
recv_src_rank = model.graph.input.add()
|
||||
recv_src_rank.CopyFrom(helper.make_tensor_value_info(
|
||||
recv_src_rank_name, onnx.TensorProto.INT64, None))
|
||||
recv_src_rank = helper.make_tensor(
|
||||
recv_src_rank_name, TensorProto.INT64, (), (cut_index,))
|
||||
model.graph.initializer.extend([recv_src_rank])
|
||||
|
||||
# output signal from send after cut
|
||||
send_output_signal = model.graph.output.add()
|
||||
|
|
@ -103,41 +95,23 @@ def split_graph(model, split_edge_groups):
|
|||
new_send = model.graph.node.add()
|
||||
new_send.CopyFrom(helper.make_node(
|
||||
'Send',
|
||||
inputs=['send_input_signal' + str(cut_index)],
|
||||
inputs=[send_input_signal_name, send_dst_rank_name],
|
||||
outputs=['send_output_signal' + str(cut_index)],
|
||||
tag=0,
|
||||
src=cut_index,
|
||||
dst=cut_index + 1,
|
||||
domain=ms_domain,
|
||||
element_type=7, # assuming all tensors are of type float
|
||||
element_types=element_types,
|
||||
name='send'))
|
||||
|
||||
new_receive = model.graph.node.add()
|
||||
new_receive.CopyFrom(helper.make_node(
|
||||
'Recv',
|
||||
inputs=['recv_input_signal' + str(cut_index)],
|
||||
inputs=[recv_input_signal_name, recv_src_rank_name],
|
||||
outputs=['receive_output_signal' + str(cut_index)],
|
||||
tag=1,
|
||||
src=cut_index,
|
||||
dst=cut_index + 1,
|
||||
tag=0,
|
||||
domain=ms_domain,
|
||||
element_type=7, # assuming all tensors are of type float
|
||||
element_types=element_types,
|
||||
name='receive'))
|
||||
|
||||
new_wait = model.graph.node.add()
|
||||
new_wait.CopyFrom(helper.make_node(
|
||||
'WaitEvent',
|
||||
inputs=['wait_input_signal' + str(cut_index)],
|
||||
outputs=[],
|
||||
domain=ms_domain))
|
||||
|
||||
new_record = model.graph.node.add()
|
||||
new_record.CopyFrom(helper.make_node(
|
||||
'RecordEvent',
|
||||
inputs=['record_input_signal' + str(cut_index)],
|
||||
outputs=[],
|
||||
domain=ms_domain))
|
||||
|
||||
for i in range(len(upstream_nodes)):
|
||||
n = upstream_nodes[i]
|
||||
idx = upstream_nodes_output_index[i]
|
||||
|
|
@ -155,24 +129,16 @@ def split_graph(model, split_edge_groups):
|
|||
'_recv' + str(cut_index)
|
||||
add_expand_type(model, new_receive_output_name, output_type)
|
||||
|
||||
new_wait_output_name = output_edge_name + '_wait' + str(cut_index)
|
||||
add_expand_type(model, new_wait_output_name, output_type)
|
||||
|
||||
# the order of data flow is: node-output -> record -> send -> recv -> wait -> node-input
|
||||
new_record.input.extend([output_edge_name])
|
||||
new_record.output.extend([new_send_input_name])
|
||||
|
||||
new_send.input.extend([new_send_input_name])
|
||||
new_send.input.extend([output_edge_name])
|
||||
new_receive.output.extend([new_receive_output_name])
|
||||
|
||||
new_wait.input.extend([new_receive_output_name])
|
||||
new_wait.output.extend([new_wait_output_name])
|
||||
|
||||
for output_node in output_nodes:
|
||||
for i in range(len(output_node.input)):
|
||||
for edgeId in edgeIds:
|
||||
if output_node.input[i] == edgeId:
|
||||
output_node.input[i] = new_wait_output_name
|
||||
output_node.input[i] = new_receive_output_name
|
||||
|
||||
new_send_nodes.append(new_send)
|
||||
new_recv_nodes.append(new_receive)
|
||||
|
|
@ -236,9 +202,50 @@ def add_identity(model, cuttingEdge, newEdgeIdName):
|
|||
if output_nodes[i].input[j] == edgeId:
|
||||
output_nodes[i].input[j] = newEdgeIdName
|
||||
|
||||
return newEdgeIdName
|
||||
return new_identity
|
||||
|
||||
|
||||
def insert_identity(model, all_cut_inputs):
|
||||
count = 0
|
||||
updated_edges = {}
|
||||
new_added_identity = []
|
||||
split_edge_groups = []
|
||||
need_shape_inference = False
|
||||
# Sweep the cut edge to see if there are edges feeding into nodes from two sub-graphs. If so,
|
||||
# insert identity node after those edges with a new ID to distinguish the rest.
|
||||
for cut_input in all_cut_inputs:
|
||||
split_edges = []
|
||||
for i in cut_input:
|
||||
if i.consumingNodes:
|
||||
# if this edge has previously been modified, update its edgeId before inserting new identity
|
||||
if i.edgeId in updated_edges:
|
||||
i.edgeId = updated_edges[i.edgeId]
|
||||
|
||||
new_edge_name = 'identity_output_' + str(count)
|
||||
new_added_identity.append(
|
||||
add_identity(model, i, new_edge_name))
|
||||
count += 1
|
||||
split_edges.append(new_edge_name)
|
||||
updated_edges[i.edgeId] = new_edge_name
|
||||
need_shape_inference = True
|
||||
else:
|
||||
split_edges.append(i.edgeId)
|
||||
split_edge_groups.append(split_edges)
|
||||
return split_edge_groups, new_added_identity, need_shape_inference
|
||||
|
||||
# after the graph is split, remove the added identity node because identity op is not registered in gradient builder.
|
||||
|
||||
|
||||
def remove_identity(model, new_added_identity):
|
||||
for node in new_added_identity:
|
||||
assert node.op_type == 'Identity'
|
||||
output_nodes = [
|
||||
n for n in model.graph.node if node.output[0] in n.input]
|
||||
for output_node in output_nodes:
|
||||
for i in range(len(output_node.input)):
|
||||
if output_node.input[i] == node.output[0]:
|
||||
output_node.input[i] = node.input[0]
|
||||
|
||||
def find_all_connected_nodes(model, node):
|
||||
nodes0, inputs = find_all_input_nodes(model, node)
|
||||
nodes1, outputs = find_all_output_nodes(model, node)
|
||||
|
|
@ -251,15 +258,36 @@ def get_index(node_list, node):
|
|||
found = [i for i, n in enumerate(node_list) if n == node]
|
||||
return found[0] if found else None
|
||||
|
||||
def get_identity_index_for_deleting(node_list, node):
|
||||
for i, n in enumerate(node_list):
|
||||
# The node's input name has been changed during send/recv insertion,
|
||||
# but it is sufficient to just compare the type and outputs.
|
||||
if (n.op_type == 'Identity' and n.output == node.output):
|
||||
return i
|
||||
return None
|
||||
|
||||
# traverse the graph, group connected nodes and generate subgraph
|
||||
|
||||
|
||||
def generate_subgraph(model, start_nodes):
|
||||
def generate_subgraph(model, start_nodes, identity_node_list):
|
||||
subgraphs = []
|
||||
|
||||
main_graph = onnx.ModelProto()
|
||||
main_graph.CopyFrom(model)
|
||||
|
||||
# remove added identity node before copy to subgraph
|
||||
identity_node_index = []
|
||||
for n in identity_node_list:
|
||||
identity_node_index.append(get_identity_index_for_deleting(main_graph.graph.node, n))
|
||||
identity_node_index.sort(reverse=True)
|
||||
|
||||
for i in reversed(range(len(main_graph.graph.node))):
|
||||
try:
|
||||
if i in identity_node_index:
|
||||
del main_graph.graph.node[i]
|
||||
except:
|
||||
print("error deleting identity node", i)
|
||||
|
||||
all_visited_nodes = []
|
||||
model_count = len(start_nodes)
|
||||
for start in reversed(start_nodes):
|
||||
|
|
@ -362,29 +390,8 @@ def main():
|
|||
output_model_names = [os.path.splitext(input_model_name)[0] + '_' +
|
||||
str(i) + '.onnx' for i in range(stage_count)]
|
||||
|
||||
split_edge_groups = []
|
||||
count = 0
|
||||
updated_edges = {}
|
||||
need_shape_inference = False
|
||||
# Sweep the cut edge to see if there are edges feeding into nodes from two sub-graphs. If so,
|
||||
# insert identity node after those edges with a new ID to distinguish the rest.
|
||||
for cut_input in all_cut_inputs:
|
||||
split_edges = []
|
||||
for i in cut_input:
|
||||
if i.consumingNodes:
|
||||
# if this edge has previously been modified, update its edgeId before inserting new identity
|
||||
if i.edgeId in updated_edges:
|
||||
i.edgeId = updated_edges[i.edgeId]
|
||||
split_edge_groups, new_identity, need_shape_inference = insert_identity(model, all_cut_inputs)
|
||||
|
||||
new_edge_name = 'identity_output_' + str(count)
|
||||
add_identity(model, i, new_edge_name)
|
||||
count += 1
|
||||
split_edges.append(new_edge_name)
|
||||
updated_edges[i.edgeId] = new_edge_name
|
||||
need_shape_inference = True
|
||||
else:
|
||||
split_edges.append(i.edgeId)
|
||||
split_edge_groups.append(split_edges)
|
||||
|
||||
# a new edge is being added, so shape inference needs to be re-run
|
||||
if need_shape_inference:
|
||||
|
|
@ -392,11 +399,13 @@ def main():
|
|||
|
||||
# after all need-to-be-cut edges identified, split the graph
|
||||
new_sends, new_receives = split_graph(model, split_edge_groups)
|
||||
sub_graphs = generate_subgraph(model, new_receives)
|
||||
remove_identity(model, new_identity)
|
||||
sub_graphs = generate_subgraph(model, new_receives, new_identity)
|
||||
|
||||
for i in range(stage_count):
|
||||
sub_graphs[i] = onnx.shape_inference.infer_shapes(sub_graphs[i])
|
||||
onnx.save(sub_graphs[i], output_model_names[i])
|
||||
print("save to file: ", output_model_names[i])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in a new issue