Merge remote-tracking branch 'origin/ort_training' into edgchen1/merge_from_ort_training

This commit is contained in:
Edward Chen 2020-04-24 00:19:05 +00:00
commit 4416d41874
58 changed files with 1614 additions and 224 deletions

View file

@ -1110,6 +1110,14 @@ class Graph {
// Graph value_info.
std::vector<const NodeArg*> value_info_;
// Strings which have been used as node names.
// New node name should not conflict with this set.
std::unordered_set<std::string> generated_node_names_;
// Strings which have been used as node_arg names.
// New node_arg name should not conflict this this set.
std::unordered_set<std::string> generated_node_arg_names_;
// All node args owned by <*this> graph. Key is node arg name.
std::unordered_map<std::string, std::unique_ptr<NodeArg>> node_args_;

View file

@ -3,7 +3,6 @@
#include "attention.h"
#include "core/framework/tensorprotoutils.h"
#include "core/providers/cuda/cudnn_common.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
#include "attention_impl.h"

View file

@ -5,7 +5,7 @@
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cuda/cudnn_common.h"
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cpu/bert/attention.h"
namespace onnxruntime {

View file

@ -2,7 +2,6 @@
// Licensed under the MIT License.
#include "core/providers/common.h"
#include "core/providers/cuda/cudnn_common.h"
#include "core/framework/tensorprotoutils.h"
#include "onnx/defs/tensor_proto_util.h"
#include "contrib_ops/cpu/bert/embed_layer_norm_helper.h"

View file

@ -2,7 +2,7 @@
// Licensed under the MIT License.
#include "core/providers/common.h"
#include "core/providers/cuda/cudnn_common.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/framework/tensorprotoutils.h"
#include "fast_gelu.h"
#include "fast_gelu_impl.h"

View file

@ -2,7 +2,7 @@
// Licensed under the MIT License.
#include "core/providers/common.h"
#include "core/providers/cuda/cudnn_common.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/framework/tensorprotoutils.h"
#include "onnx/defs/tensor_proto_util.h"
#include "skip_layer_norm.h"

View file

@ -5,7 +5,7 @@
#include "layer_norm_impl.h"
#include "core/providers/common.h"
#include "core/providers/cuda/cudnn_common.h"
#include "core/providers/cuda/cuda_common.h"
namespace onnxruntime {
namespace contrib {

View file

@ -107,13 +107,14 @@ __device__ void cuWelfordMuSigma2(
cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x + (1 << l)) & 31;
U muB = WARP_SHFL(mu, srcLaneB);
U countB = WARP_SHFL(count, srcLaneB);
U sigma2B = WARP_SHFL(sigma2, srcLaneB);
#pragma unroll
for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) {
U muB = WARP_SHFL_DOWN(mu, stride);
U countB = WARP_SHFL_DOWN(count, stride);
U sigma2B = WARP_SHFL_DOWN(sigma2, stride);
cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
@ -192,8 +193,8 @@ __device__ void cuWelfordMuSigma2(
for (; l + 7 < n2; l += 8 * numx) {
for (int k = 0; k < 8; k += 2) {
float2 curr = __half22float2(*((__half2*)(lvals + l + k)));
cuWelfordOnlineSum(curr.x, mu, sigma2, count);
cuWelfordOnlineSum(curr.y, mu, sigma2, count);
cuWelfordOnlineSum(static_cast<float>(curr.x), mu, sigma2, count);
cuWelfordOnlineSum(static_cast<float>(curr.y), mu, sigma2, count);
}
}
for (; l < n2; ++l) {
@ -201,13 +202,14 @@ __device__ void cuWelfordMuSigma2(
cuWelfordOnlineSum(curr, mu, sigma2, count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x + (1 << l)) & 31;
float muB = WARP_SHFL(mu, srcLaneB);
float countB = WARP_SHFL(count, srcLaneB);
float sigma2B = WARP_SHFL(sigma2, srcLaneB);
#pragma unroll
for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) {
float muB = WARP_SHFL_DOWN(mu, stride);
float countB = WARP_SHFL_DOWN(count, stride);
float sigma2B = WARP_SHFL_DOWN(sigma2, stride);
cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
@ -310,7 +312,7 @@ __global__ void cuApplyLayerNorm(
// 1) blockDim.x == GPU_WARP_SIZE
// 2) Tensors are contiguous
//
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
U mu, sigma2;

View file

@ -5,7 +5,7 @@
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cuda/cudnn_common.h"
#include "core/providers/cuda/cuda_common.h"
namespace onnxruntime {
namespace contrib {

View file

@ -2376,16 +2376,47 @@ Node& Graph::AddNode(const NodeProto& node_proto,
}
std::string Graph::GenerateNodeArgName(const std::string& base_name) {
// Check if base_name has been used in as any of node_args_' names.
bool found = node_args_.find(base_name) != node_args_.end();
// Check if base_name has been generated by this function.
// If not, add base_name into name set and return the base_name
// as the generated name.
found |= generated_node_arg_names_.find(base_name) != generated_node_arg_names_.end();
if (!found) {
generated_node_arg_names_.insert(base_name);
return base_name;
}
// base_name has been used by another node. Because two node_arg's cannot have
// the sam name, we are going to generate another string.
std::string new_name;
do {
std::ostringstream str;
str << base_name << "_" << name_generator_++;
new_name = str.str();
} while (node_args_.find(new_name) != node_args_.end());
// If node_args_ or generated_node_arg_names_ contains new_name, we go to the next iteration.
} while (node_args_.find(new_name) != node_args_.end() ||
generated_node_arg_names_.find(new_name) != generated_node_arg_names_.end());
// Now new_name is different than any of existing node_arg names.
// Register new_name so that it won't be used again.
generated_node_arg_names_.insert(new_name);
return new_name;
}
std::string Graph::GenerateNodeName(const std::string& base_name) {
// Check if base_name has been used in as any of nodes_' names.
bool found = std::find_if(nodes_.cbegin(), nodes_.cend(), [&base_name](const std::unique_ptr<Node>& n) {
return (n != nullptr) && (n->Name() == base_name);}) != nodes_.end();
// Check if base_name has been generated by this function.
found |= generated_node_names_.find(base_name) != generated_node_names_.end();
if (!found) {
// Register base_name so that it won't be used again.
generated_node_names_.insert(base_name);
return base_name;
}
std::string new_name;
bool keep_going = true;
@ -2394,11 +2425,18 @@ std::string Graph::GenerateNodeName(const std::string& base_name) {
str << base_name << "_" << name_generator_++;
new_name = str.str();
// Check if new_name has been used in as any of nodes_' names.
keep_going = std::find_if(nodes_.cbegin(), nodes_.cend(), [&new_name](const std::unique_ptr<Node>& n) {
return (n != nullptr) && (n->Name() == new_name);
}) != nodes_.end();
// Check if new_name has been generated by this function.
keep_going |= generated_node_names_.find(new_name) != generated_node_names_.end();
} while (keep_going);
// Now new_name is different than any of existing node names.
// Register new_name so that it won't be used again.
generated_node_names_.insert(new_name);
return new_name;
}

View file

@ -0,0 +1,74 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/common/logging/logging.h"
#include "core/framework/tensorprotoutils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/rewrite_rule.h"
#include "core/optimizer/expand_elimination.h"
#include "core/graph/graph.h"
#include "core/graph/graph_utils.h"
namespace onnxruntime {
Status ExpandElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
if (graph_utils::RemoveNode(graph, node)) {
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
}
return Status::OK();
}
bool ExpandElimination::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
if (!graph_utils::CanRemoveNode(graph, node, logger)) {
return false;
}
// 1. Check if has input shape.
const auto* input_shape = node.InputDefs()[0]->Shape();
if (input_shape == nullptr) {
return false;
}
// 2. Get target shape if it's constant.
const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name());
if (tensor_proto == nullptr || tensor_proto->dims_size() != 1 || tensor_proto->dims(0) <= 0) {
return false;
}
auto initializer = onnxruntime::make_unique<Initializer>(*tensor_proto, graph.ModelPath());
if (initializer->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT64) {
return false;
}
const int64_t* target_shapes = initializer->data<int64_t>();
// Check the dimensions starting at the trailing dimension.
int i = input_shape->dim_size() - 1;
int j = static_cast<int>(tensor_proto->dims(0) - 1);
// The Expand produces same input tensor only when target dimension size is not greater than input's.
if (i < j) {
return false;
}
while (i >= 0 && j >= 0) {
auto dim = input_shape->dim(i);
if (utils::HasDimValue(dim)) {
auto dim_value = dim.dim_value();
if (dim_value != target_shapes[j] && target_shapes[j] > 1) {
return false;
}
} else if (target_shapes[j] > 1) {
return false;
}
--i;
--j;
}
return true;
}
} // namespace onnxruntime

View file

@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/rewrite_rule.h"
namespace onnxruntime {
/**
@Class ExpandElimination
Rewrite rule that eliminates Expand nodes if the node generate the same tensor as the input tensor.
It is attempted to be triggered only on nodes with op type "Expand".
*/
class ExpandElimination : public RewriteRule {
public:
ExpandElimination() noexcept : RewriteRule("ExpandElimination") {}
std::vector<std::string> TargetOpTypes() const noexcept override {
return {"Expand"};
}
private:
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -27,6 +27,7 @@
#include "core/optimizer/embed_layer_norm_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/attention_fusion.h"
#include "core/optimizer/expand_elimination.h"
#include "core/optimizer/cast_elimination.h"
#include "core/mlas/inc/mlas.h"
@ -47,6 +48,7 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(TransformerLevel
rules.push_back(onnxruntime::make_unique<EliminateSlice>());
rules.push_back(onnxruntime::make_unique<UnsqueezeElimination>());
rules.push_back(onnxruntime::make_unique<EliminateDropout>());
rules.push_back(onnxruntime::make_unique<ExpandElimination>());
rules.push_back(onnxruntime::make_unique<CastElimination>());
rules.push_back(onnxruntime::make_unique<FuseReluClip>());
rules.push_back(onnxruntime::make_unique<ShapeToInitializer>());

View file

@ -3,7 +3,7 @@
#include "gemm.h"
#include "core/providers/cpu/math/gemm_helper.h"
#include "core/providers/cuda/cudnn_common.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
namespace onnxruntime {

View file

@ -17,7 +17,6 @@
// The code below is mostly copied from Pytorch PersistentSoftmax.cuh
#pragma once
#include "core/providers/cuda/cu_inc/common.cuh"
namespace onnxruntime {

View file

@ -4,7 +4,7 @@
#pragma once
#include "gsl/gsl"
#include "core/providers/cuda/cudnn_common.h"
#include "core/providers/cuda/cuda_common.h"
namespace onnxruntime {
namespace cuda {

View file

@ -11,7 +11,7 @@ namespace cuda {
class Concat final : public CudaKernel, public ConcatBase {
public:
Concat(const OpKernelInfo& info) : ConcatBase(info), CudaKernel(info) {}
Concat(const OpKernelInfo& info) : CudaKernel(info), ConcatBase(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};

View file

@ -11,7 +11,7 @@ namespace cuda {
class Gather final : public CudaKernel, public GatherBase {
public:
Gather(const OpKernelInfo& info) : GatherBase(info), CudaKernel(info) {}
Gather(const OpKernelInfo& info) : CudaKernel(info), GatherBase(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};

View file

@ -1,13 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cuda/atomic/common.cuh"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "scatter_elements_impl.h"
#ifdef ENABLE_TRAINING
#include "orttraining/training_ops/cuda/tensor/gather_elements_grad_impl.h"
#endif
namespace onnxruntime {
namespace cuda {
template <typename T, typename Tin, bool OUTERAXIS>
template <typename T, typename Tin, bool OUTERAXIS, typename FuncT>
__global__ void _ScatterElementsKernel2D(
const int max_dim, // max dim on the scattered axis
const T* input_data,
@ -16,7 +20,8 @@ __global__ void _ScatterElementsKernel2D(
const fast_divmod indices_stride_row,
const T* updates,
const int64_t output_row_size,
T* output_data) {
T* output_data,
const FuncT& func) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(indices_index, indices_size);
int row, col, data_idx;
@ -29,12 +34,13 @@ __global__ void _ScatterElementsKernel2D(
} else {
data_idx = row * output_row_size + dim;
}
output_data[data_idx] = updates[indices_index];
func(output_data + data_idx, updates + indices_index);
}
// else invalid index
}
template <typename T, typename Tin>
template <typename T, typename Tin, typename FuncT>
__global__ void _ScatterElementsKernel(
const int rank,
const T* input_data,
@ -46,7 +52,8 @@ __global__ void _ScatterElementsKernel(
const TArray<fast_divmod> indices_strides,
const T* updates,
const int axis,
T* output_data) {
T* output_data,
const FuncT& func) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(indices_index, indices_size);
int dim, remain = indices_index;
size_t data_idx = 0;
@ -61,7 +68,7 @@ __global__ void _ScatterElementsKernel(
}
data_idx += input_strides[i] * dim;
}
output_data[data_idx] = updates[indices_index];
func(output_data + data_idx, updates + indices_index);
}
// From the innermost axis (largest) check equality of dim value of input and indices.
@ -135,7 +142,7 @@ static int CompactInputIndicesDims(
return new_axis;
}
template <typename T, typename Tin>
template <typename T, typename Tin, typename FuncT>
Status ScatterElementsImpl2D(
const T* input_data,
const std::vector<int64_t>& input_dims,
@ -144,23 +151,70 @@ Status ScatterElementsImpl2D(
const std::vector<int64_t>& indices_dims,
const T* updates,
const int axis,
T* output_data) {
T* output_data,
const FuncT& func) {
int blocksPerGrid = gsl::narrow_cast<int>(CeilDiv(indices_size, GridDim::maxThreadsPerBlock));
fast_divmod indices_stride_row(indices_dims[1]);
if (axis == 0) {
_ScatterElementsKernel2D<T, Tin, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
_ScatterElementsKernel2D<T, Tin, true, FuncT><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
gsl::narrow_cast<int>(input_dims[0]), input_data,
indices_data, indices_size, indices_stride_row,
updates, input_dims[1], output_data);
updates, input_dims[1], output_data, func);
} else {
_ScatterElementsKernel2D<T, Tin, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
_ScatterElementsKernel2D<T, Tin, false, FuncT><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
gsl::narrow_cast<int>(input_dims[1]), input_data,
indices_data, indices_size, indices_stride_row,
updates, input_dims[1], output_data);
updates, input_dims[1], output_data, func);
}
return Status::OK();
}
template <typename T, typename Tin, typename FuncT>
Status ScatterElementsImplInternal(
const int rank,
const T* input_data,
const int64_t input_size,
TArray<int64_t>& buffer_input_dims,
TArray<int64_t>& buffer_input_strides,
const Tin* indices_data,
const int64_t indices_size,
TArray<int64_t>& buffer_indices_dims,
TArray<fast_divmod>& fdm_indices_strides,
const T* updates,
const int axis,
T* output_data,
const FuncT& func) {
if (input_data != output_data) {
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_data, input_data, input_size * sizeof(T), cudaMemcpyDeviceToDevice, 0));
}
if (indices_size > 0) {
std::vector<int64_t> eff_input_dims;
std::vector<int64_t> eff_indices_dims;
int new_axis = CompactInputIndicesDims(
rank, axis, buffer_input_dims.data_, buffer_indices_dims.data_, eff_input_dims, eff_indices_dims);
if (eff_input_dims.size() == 2) {
return ScatterElementsImpl2D(
input_data, eff_input_dims, indices_data, indices_size, eff_indices_dims, updates, new_axis, output_data,
func);
}
int blocksPerGrid = gsl::narrow_cast<int>(CeilDiv(indices_size, GridDim::maxThreadsPerBlock));
_ScatterElementsKernel<T, Tin><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
rank, input_data, buffer_input_dims, buffer_input_strides,
indices_data, indices_size, buffer_indices_dims, fdm_indices_strides,
updates, axis, output_data, func);
}
return Status::OK();
}
template <class T>
struct Func_Assignment {
__device__ __inline__ void operator()(T* a, const T* b) const {
*a = *b;
}
};
template <typename T, typename Tin>
Status ScatterElementsImpl(
const int rank,
@ -175,60 +229,94 @@ Status ScatterElementsImpl(
const T* updates,
const int axis,
T* output_data) {
if (input_data != output_data) {
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_data, input_data, input_size * sizeof(T), cudaMemcpyDeviceToDevice, 0));
}
if (indices_size > 0) {
std::vector<int64_t> eff_input_dims;
std::vector<int64_t> eff_indices_dims;
int new_axis = CompactInputIndicesDims(
rank, axis, buffer_input_dims.data_, buffer_indices_dims.data_, eff_input_dims, eff_indices_dims);
if (eff_input_dims.size() == 2) {
return ScatterElementsImpl2D(
input_data, eff_input_dims, indices_data, indices_size, eff_indices_dims, updates, new_axis, output_data);
}
int blocksPerGrid = gsl::narrow_cast<int>(CeilDiv(indices_size, GridDim::maxThreadsPerBlock));
_ScatterElementsKernel<T, Tin><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
rank, input_data, buffer_input_dims, buffer_input_strides,
indices_data, indices_size, buffer_indices_dims, fdm_indices_strides,
updates, axis, output_data);
}
return Status::OK();
return ScatterElementsImplInternal(rank, input_data, input_size, buffer_input_dims,
buffer_input_strides, indices_data, indices_size, buffer_indices_dims, fdm_indices_strides,
updates, axis, output_data, Func_Assignment<T>());
}
#define SPECIALIZED_TINDEX_IMPL(T, TIndex) \
template Status ScatterElementsImpl<T, TIndex>( \
const int rank, \
const T* input_data, \
const int64_t input_size, \
TArray<int64_t>& buffer_input_dims, \
TArray<int64_t>& buffer_input_strides, \
const TIndex* indices_data, \
const int64_t indices_size, \
TArray<int64_t>& buffer_indices_dims, \
TArray<fast_divmod>& indices_strides, \
const T* updates, \
const int axis, \
#define SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, TIndex) \
template Status ScatterElementsImpl<T, TIndex>( \
const int rank, \
const T* input_data, \
const int64_t input_size, \
TArray<int64_t>& buffer_input_dims, \
TArray<int64_t>& buffer_input_strides, \
const TIndex* indices_data, \
const int64_t indices_size, \
TArray<int64_t>& buffer_indices_dims, \
TArray<fast_divmod>& indices_strides, \
const T* updates, \
const int axis, \
T* output_data)
#define SPECIALIZED_IMPL(T) \
SPECIALIZED_TINDEX_IMPL(T, int32_t); \
SPECIALIZED_TINDEX_IMPL(T, int64_t);
#define SCATTER_ELEMENTS_SPECIALIZED_IMPL(T) \
SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, int32_t); \
SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, int64_t);
SPECIALIZED_IMPL(int8_t)
SPECIALIZED_IMPL(int16_t)
SPECIALIZED_IMPL(int32_t)
SPECIALIZED_IMPL(int64_t)
SPECIALIZED_IMPL(uint8_t)
SPECIALIZED_IMPL(uint16_t)
SPECIALIZED_IMPL(uint32_t)
SPECIALIZED_IMPL(uint64_t)
SPECIALIZED_IMPL(half)
SPECIALIZED_IMPL(float)
SPECIALIZED_IMPL(double)
SPECIALIZED_IMPL(bool)
SCATTER_ELEMENTS_SPECIALIZED_IMPL(int8_t)
SCATTER_ELEMENTS_SPECIALIZED_IMPL(int16_t)
SCATTER_ELEMENTS_SPECIALIZED_IMPL(int32_t)
SCATTER_ELEMENTS_SPECIALIZED_IMPL(int64_t)
SCATTER_ELEMENTS_SPECIALIZED_IMPL(uint8_t)
SCATTER_ELEMENTS_SPECIALIZED_IMPL(uint16_t)
SCATTER_ELEMENTS_SPECIALIZED_IMPL(uint32_t)
SCATTER_ELEMENTS_SPECIALIZED_IMPL(uint64_t)
SCATTER_ELEMENTS_SPECIALIZED_IMPL(half)
SCATTER_ELEMENTS_SPECIALIZED_IMPL(float)
SCATTER_ELEMENTS_SPECIALIZED_IMPL(double)
SCATTER_ELEMENTS_SPECIALIZED_IMPL(bool)
#ifdef ENABLE_TRAINING
template <class T>
struct Func_AtomicAdd {
__device__ __inline__ void operator()(T* a, const T* b) const {
atomic_add(a, *b);
}
};
template <typename T, typename Tin>
Status GatherElementsGradImpl(
const int rank,
TArray<int64_t>& buffer_input_dims,
TArray<int64_t>& buffer_input_strides,
const Tin* indices_data,
const int64_t indices_size,
TArray<int64_t>& buffer_indices_dims,
TArray<fast_divmod>& fdm_indices_strides,
const T* updates,
const int axis,
T* output_data) {
// Give output_data as the input_data parameter by intention,
// to skip input_data copy, which is not applicable for GatherElementsGrad.
return ScatterElementsImplInternal(rank, output_data, 0,
buffer_input_dims, buffer_input_strides, indices_data,
indices_size, buffer_indices_dims, fdm_indices_strides,
updates, axis, output_data, Func_AtomicAdd<T>());
}
#define GATHER_ELEMENTS_GRAD_SPECIALIZED_TINDEX_IMPL(T, TIndex) \
template Status GatherElementsGradImpl<T, TIndex>( \
const int rank, \
TArray<int64_t>& buffer_input_dims, \
TArray<int64_t>& buffer_input_strides, \
const TIndex* indices_data, \
const int64_t indices_size, \
TArray<int64_t>& buffer_indices_dims, \
TArray<fast_divmod>& indices_strides, \
const T* updates, \
const int axis, \
T* output_data)
#define GATHER_ELEMENTS_GRAD_SPECIALIZED_SCATTER_ADD_IMPL(T) \
GATHER_ELEMENTS_GRAD_SPECIALIZED_TINDEX_IMPL(T, int32_t); \
GATHER_ELEMENTS_GRAD_SPECIALIZED_TINDEX_IMPL(T, int64_t);
GATHER_ELEMENTS_GRAD_SPECIALIZED_SCATTER_ADD_IMPL(half)
GATHER_ELEMENTS_GRAD_SPECIALIZED_SCATTER_ADD_IMPL(float)
GATHER_ELEMENTS_GRAD_SPECIALIZED_SCATTER_ADD_IMPL(double)
#endif
} // namespace cuda
} // namespace onnxruntime

View file

@ -26,4 +26,3 @@ Status ScatterElementsImpl(
} // namespace cuda
} // namespace onnxruntime

View file

@ -40,6 +40,7 @@
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/attention_fusion.h"
#include "core/optimizer/fast_gelu_fusion.h"
#include "core/optimizer/expand_elimination.h"
#include "core/optimizer/cast_elimination.h"
#include "core/optimizer/utils.h"
#include "core/platform/env.h"
@ -1115,6 +1116,24 @@ TEST_F(GraphTransformationTests, ReshapeFusionInternalReuseTest) {
}
}
TEST_F(GraphTransformationTests, ExpandElimination) {
auto model_uri = MODEL_FOLDER "expand_elimination.onnx";
std::shared_ptr<Model> model;
ASSERT_TRUE(Model::Load(model_uri, model, nullptr, *logger_).IsOK());
Graph& graph = model->MainGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Expand"] == 6);
auto rule_transformer_L1 = onnxruntime::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
rule_transformer_L1->Register(onnxruntime::make_unique<ExpandElimination>());
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_).IsOK());
op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Expand"] == 3);
}
TEST_F(GraphTransformationTests, CastElimination) {
auto model_uri = MODEL_FOLDER "cast_elimination.onnx";
std::shared_ptr<Model> model;

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View file

@ -0,0 +1,56 @@
import onnx
from onnx import helper
from onnx import TensorProto, GraphProto, OperatorSetIdProto
from onnx import numpy_helper
import numpy as np
X1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [2, 1])
X2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, ['dynamic', 4])
Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2, 4])
shape_constant1 = numpy_helper.from_array(np.array([1, 4], dtype=np.int64), name='shape_constant1')
shape_constant2 = numpy_helper.from_array(np.array([1, 1], dtype=np.int64), name='shape_constant2')
shape_constant3 = numpy_helper.from_array(np.array([2, 1], dtype=np.int64), name='shape_constant3')
shape_constant4 = numpy_helper.from_array(np.array([1, 1, 1], dtype=np.int64), name='shape_constant4')
shape_constant5 = numpy_helper.from_array(np.array([1, 4], dtype=np.int64), name='shape_constant5')
shape_constant6 = numpy_helper.from_array(np.array([2, 1], dtype=np.int64), name='shape_constant6')
identity1 = helper.make_node('Identity', ['input1'], ['identity1'], name='identity1')
expand1 = helper.make_node('Expand', ['identity1', shape_constant1.name], ['expand1'], name='expand1')
expand2 = helper.make_node('Expand', ['identity1', shape_constant2.name], ['expand2'], name='expand2')
mul1 = helper.make_node('Mul', ['expand1', 'expand2'], ['mul1'], name='mul1') # (2, 4)
expand3 = helper.make_node('Expand', ['mul1', shape_constant3.name], ['expand3'], name='expand3')
expand4 = helper.make_node('Expand', ['identity1', shape_constant4.name], ['expand4'], name='expand4')
mul2 = helper.make_node('Mul', ['expand3', 'expand4'], ['mul2'], name='mul2') # (1, 2, 4)
identity2 = helper.make_node('Identity', ['input2'], ['identity2'], name='identity2')
expand5 = helper.make_node('Expand', ['identity2', shape_constant5.name], ['expand5'], name='expand5')
expand6 = helper.make_node('Expand', ['identity2', shape_constant6.name], ['expand6'], name='expand6')
mul3 = helper.make_node('Mul', ['expand5', 'expand6'], ['mul3'], name='mul3') # (dynamic=2, 4)
mul4 = helper.make_node('Mul', ['mul2', 'mul3'], ['output'], name='mul4')
# Create the graph (GraphProto)
graph_def = helper.make_graph(
[identity1, expand1, expand2, mul1, expand3, expand4, mul2, identity2, expand5, expand6, mul3, mul4],
'expand_elimination_model',
[X1, X2],
[Y],
[shape_constant1, shape_constant2, shape_constant3, shape_constant4, shape_constant5, shape_constant6]
)
opsets = []
onnxdomain = OperatorSetIdProto()
onnxdomain.version = 12
onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
opsets.append(onnxdomain)
msdomain = OperatorSetIdProto()
msdomain.version = 1
msdomain.domain = 'com.microsoft'
opsets.append(msdomain)
kwargs={}
kwargs['opset_imports'] = opsets
# Create the model (ModelProto)
model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs)
onnx.save(model_def, 'expand_elimination.onnx')

View file

@ -573,6 +573,17 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherGradient) {
SrcNodeAttributes())};
}
IMPLEMENT_GRADIENT_BUILDER(GetGatherElementsGradient) {
return std::vector<NodeDef>{
NodeDef("Shape",
{I(0)},
{IA("x_shape")}),
NodeDef(OpDef{"GatherElementsGrad", kMSDomain, 1},
{GO(0), IA("x_shape"), I(1)},
{GI(0)},
SrcNodeAttributes())};
};
IMPLEMENT_GRADIENT_BUILDER(GetReluGradient) {
return std::vector<NodeDef>{
NodeDef("ReluGrad",
@ -1034,7 +1045,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetSendGradient) {
}
return std::vector<NodeDef>{
NodeDef("Recv",
NodeDef(OpDef{"Recv", kMSDomain, 1},
{O(0), I(1)}, // {Signal, Remote}
out_args,
SrcNodeAttributes())};
@ -1046,18 +1057,38 @@ IMPLEMENT_GRADIENT_BUILDER(GetRecvGradient) {
std::vector<ArgDef> in_args;
in_args.push_back(O(0)); // Signal
in_args.push_back(I(0)); // Remote
in_args.push_back(I(1)); // Remote
for (int i = 1; i < GetSrcNodeOutputSize(); ++i) {
in_args.push_back(GO(i)); // Data
}
return std::vector<NodeDef>{
NodeDef("Send",
NodeDef(OpDef{"Send", kMSDomain, 1},
in_args,
{GI(0)}, // Signal
SrcNodeAttributes())};
}
IMPLEMENT_GRADIENT_BUILDER(GetExpandGradient) {
ArgDef a = I(0), y = O(0);
std::vector<Dimension> a_shape = GetShape(a);
std::vector<Dimension> y_shape = GetShape(y);
std::vector<int64_t> a_axes;
ComputeBroadcastBackwardAxes(a_shape, y_shape, &a_axes, nullptr);
std::vector<NodeDef> output;
if (a_axes.size() > 0) {
HandleBroadcasting(GO(0), a, GI(0), a_axes, output);
} else {
output.push_back(
NodeDef("Identity",
{GO(0)},
{GI(0)}));
}
return output;
}
} // namespace training
} // namespace onnxruntime

View file

@ -45,6 +45,7 @@ DECLARE_GRADIENT_BUILDER(GetGemmGradient)
DECLARE_GRADIENT_BUILDER(GetDropoutGradient)
DECLARE_GRADIENT_BUILDER(GetTrainableDropoutGradient)
DECLARE_GRADIENT_BUILDER(GetGatherNDGradient)
DECLARE_GRADIENT_BUILDER(GetGatherElementsGradient)
DECLARE_GRADIENT_BUILDER(GetGeluGradient)
DECLARE_GRADIENT_BUILDER(GetLayerNormalizationGradient)
DECLARE_GRADIENT_BUILDER(GetBatchNormalizationGradient)
@ -55,6 +56,7 @@ DECLARE_GRADIENT_BUILDER(GetFastGeluGradient)
DECLARE_GRADIENT_BUILDER(GetWhereGradient)
DECLARE_GRADIENT_BUILDER(GetSendGradient)
DECLARE_GRADIENT_BUILDER(GetRecvGradient)
DECLARE_GRADIENT_BUILDER(GetExpandGradient)
} // namespace training
} // namespace onnxruntime

View file

@ -61,6 +61,7 @@ void ComputeBroadcastBackwardAxes(
}
std::vector<Dimension> GetShape(const ArgDef& arg_def) {
ORT_ENFORCE(arg_def.type_proto, "During GetShape, ", arg_def.name, "'s type_proto is null.");
std::vector<Dimension> shape;
const auto& dims = arg_def.type_proto->tensor_type().shape().dim();
for (auto dim = dims.begin(); dim < dims.end(); dim++) {

View file

@ -83,6 +83,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Dropout", GetDropoutGradient)
REGISTER_GRADIENT_BUILDER("TrainableDropout", GetTrainableDropoutGradient)
REGISTER_GRADIENT_BUILDER("GatherND", GetGatherNDGradient)
REGISTER_GRADIENT_BUILDER("GatherElements", GetGatherElementsGradient)
REGISTER_GRADIENT_BUILDER("Gelu", GetGeluGradient)
REGISTER_GRADIENT_BUILDER("LayerNormalization", GetLayerNormalizationGradient);
REGISTER_GRADIENT_BUILDER("BatchNormalization", GetBatchNormalizationGradient);
@ -93,6 +94,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Where", GetWhereGradient);
REGISTER_GRADIENT_BUILDER("Send", GetSendGradient);
REGISTER_GRADIENT_BUILDER("Recv", GetRecvGradient);
REGISTER_GRADIENT_BUILDER("Expand", GetExpandGradient);
};
} // namespace training

View file

@ -432,6 +432,43 @@ void RegisterGradientSchemas() {
{"tensor(int32)", "tensor(int64)"},
"Constrain indices to integer types");
ONNX_CONTRIB_OPERATOR_SCHEMA(GatherElementsGrad)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
.SetDoc("GatherElementsGrad")
.Attr(
"axis",
"Which axis to scatter on. Negative value means "
"counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(data).",
AttributeProto::INT,
static_cast<int64_t>(0))
.Input(
0,
"dY",
"Tensor of rank r >=1 (same rank and shape as indices)",
"T")
.Input(1, "shape", "Shape of the GatherElements input data.", "I")
.Input(
2,
"indices",
"Tensor of int32/int64 indices, of r >= 1 (same rank as input). All index values are expected to be "
"within bounds [-s, s-1] along axis of size s. It is an error if any of the index values are out of bounds.",
"Tind")
.Output(0, "dX", "Tensor of rank r >= 1 (same rank as input).", "T")
.TypeConstraint(
"I",
{"tensor(int64)"},
"Constrain input shape to integer tensors.")
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Input and output types can be of any tensor type.")
.TypeConstraint(
"Tind",
{"tensor(int32)", "tensor(int64)"},
"Constrain indices to integer types");
ONNX_CONTRIB_OPERATOR_SCHEMA(DivGrad)
.SetDomain(kMSDomain)
.SinceVersion(1)
@ -1633,7 +1670,7 @@ Return true if all elements are true and false otherwise.
"Allow inputs and outputs to be any kind of tensor.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
if (ctx.getNumInputs() < ctx.getNumOutputs() + 1)
fail_shape_inference("WaitEvent must have at least (num_outputs + 1) inputs.");
fail_shape_inference("RecordEvent must have at least (num_outputs + 1) inputs.");
// note: if num_input > num_output + 1,
// the additional inputs (idx >= num_ouput + 1) are regarded as dependencies

View file

@ -0,0 +1,279 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/core/graph/pipeline_transformer.h"
using namespace onnxruntime::common;
namespace onnxruntime {
namespace training {
void GetPipelineSendOutput(const Graph& graph, std::string& loss_name) {
for (auto& node : graph.Nodes()) {
if (!node.OpType().compare("Send")) {
// send op should always have an output, which is the OutputSignal.
loss_name = node.OutputDefs()[0]->Name();
return;
}
}
}
bool IsBackward(Node& node) {
return (node.Description() == "Backward pass");
}
void AddInputEvent(Graph& graph, const std::string& op_name,
bool is_forward,
std::vector<NodeArg*>& input_args,
std::vector<std::string>& new_input_names) {
ONNX_NAMESPACE::TypeProto event_type_proto;
event_type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
auto event_id_name = graph.GenerateNodeArgName(op_name + (is_forward ? "_fw" : "_bw") + "_event_id");
auto& event_id = graph.GetOrCreateNodeArg(event_id_name, &event_type_proto);
new_input_names.push_back(event_id_name);
input_args.push_back(&event_id);
}
// gradient graph can contain some dangling leaf nodes. Add them all to WaitEvent
// backward node's input.
void FindLeafNodes(Graph& graph, std::vector<NodeArg*>& input_args) {
for (auto& node : graph.Nodes()) {
if (!IsBackward(node)) {
// only check backward node
continue;
}
bool find_consumer_nodes = false;
std::vector<NodeArg*>& outputs = node.MutableOutputDefs();
for (auto& output : outputs) {
std::vector<const Node*> consumer_nodes = graph.GetConsumerNodes(output->Name());
if (consumer_nodes.size() > 0) {
find_consumer_nodes = true;
break;
}
}
if (!find_consumer_nodes && outputs.size() > 0) {
input_args.push_back(outputs[0]);
}
}
};
NodeArg& CreateNodeArg(Graph& graph, const NodeArg* base_arg) {
const auto& new_name = graph.GenerateNodeArgName(base_arg->Name());
ONNX_NAMESPACE::TypeProto type_proto(*(base_arg->TypeAsProto()));
if (graph.GetNodeArg(new_name) != nullptr) {
ORT_THROW("Node with name ", new_name, " already exists.");
}
return graph.GetOrCreateNodeArg(new_name, &type_proto);
}
Status AddRecordBackward(Graph& graph,
Node* send_bw,
std::vector<std::string>& new_input_names,
std::vector<std::string>& new_output_names) {
std::vector<NodeArg*> input_args;
AddInputEvent(graph, "RecordEvent", false /* is_forward */, input_args, new_input_names);
std::vector<NodeArg*> output_args{};
if (send_bw) {
// if we have send op in backward pass (at the end of the graph), we make sure the RecordEvent happens
// after that send by adding Send's outputs to RecordEvent's input list.
input_args.insert(std::end(input_args),
std::begin(send_bw->MutableOutputDefs()),
std::end(send_bw->MutableOutputDefs()));
}
FindLeafNodes(graph, input_args);
// Optimizer will be added after applying pipeline transformer. To support partial graph evaluation,
// the added Record backward op will have its first passthrough input as output.
ORT_RETURN_IF_NOT(input_args.size() >= 2, "RecordEvent backward op at least have two inputs.")
auto& new_output = CreateNodeArg(graph, input_args[1]); // the first input is signal, not passing through
output_args.push_back(&new_output);
new_output_names.push_back(new_output.Name());
graph.AddNode(graph.GenerateNodeName("RecordEvent"),
"RecordEvent",
"Backward pass",
input_args,
output_args,
nullptr,
kMSDomain);
return Status::OK();
}
Status AddWaitForward(Graph& graph, Node* /* recv_fw */, std::vector<std::string>& new_input_names) {
// Append old_input to input_args and return its pass-through value. Note that
// input_args and output_args are Wait's inputs and outputs, respectively.
auto update_wait_input_output = [&](NodeArg* old_input,
std::vector<NodeArg*>& input_args,
std::vector<NodeArg*>& output_args) -> NodeArg& {
input_args.push_back(old_input);
const auto& new_name = graph.GenerateNodeArgName(old_input->Name());
ONNX_NAMESPACE::TypeProto type_proto(*(old_input->TypeAsProto()));
auto& wait_output = graph.GetOrCreateNodeArg(new_name, &type_proto);
output_args.push_back(&wait_output);
return wait_output;
};
std::vector<NodeArg*> input_args;
std::vector<NodeArg*> output_args;
AddInputEvent(graph, "WaitEvent", true /* is_forward */, input_args, new_input_names);
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();
if (graph_inputs.size() == 0){
ORT_THROW("Graph ", graph.Name(), " doesn't have any inputs.");
}
for (auto& input_arg : graph_inputs) {
NodeArg* mutable_input = graph.GetNodeArg(input_arg->Name());
auto& wait_output = update_wait_input_output(mutable_input, input_args, output_args);
std::vector<Node*> nodes = graph.GetMutableConsumerNodes(input_arg->Name());
for (auto& consumer_node : nodes) {
for (auto& i : consumer_node->MutableInputDefs()) {
if (i->Name() == input_arg->Name()) {
// if the node is fed by input, re-direct it to be fed by WaitEvent's output.
i = &wait_output;
}
}
}
}
graph.AddNode(graph.GenerateNodeName("WaitEvent"),
"WaitEvent",
"",
input_args,
output_args,
nullptr,
kMSDomain);
return Status::OK();
}
Status AddOrSkipRecordForwardWaitBackward(Graph& graph, Node* send_fw, Node* recv_bw, std::vector<std::string>& new_input_names) {
if (!send_fw != !recv_bw){
ORT_THROW("Graph requires either having both send forward node "
"and recv backword node, or none of them. Currently the graph "
"has send forward: ", send_fw, " and recv backward: ", recv_bw);
}
if (!send_fw && !recv_bw){
// Last partition doesn't have send forwrad and recv backward. No insert needed.
return Status::OK();
}
// if we have a send forward op followed by a recv backward op, insert WaitEvent and RecordEvent in between.
Node* record_node = nullptr;
Node* wait_node = nullptr;
// Insert RecordEvent
{
std::vector<NodeArg*> input_args;
std::vector<NodeArg*> output_args;
AddInputEvent(graph, "RecordEvent", true /* is_forward */, input_args, new_input_names);
// Add send forward op's output as record op's input and output
for (auto& output : send_fw->MutableOutputDefs()) {
auto& new_output = CreateNodeArg(graph, output);
output_args.push_back(&new_output);
input_args.push_back(output);
}
auto& new_node = graph.AddNode(graph.GenerateNodeName("RecordEvent"),
"RecordEvent",
"",
input_args,
output_args, /* output */
{}, /* attribute */
kMSDomain);
record_node = &new_node;
}
// Insert WaitEvent
{
std::vector<NodeArg*> input_args;
std::vector<NodeArg*> output_args;
AddInputEvent(graph, "WaitEvent", false /* is_forward */, input_args, new_input_names);
input_args.insert(std::end(input_args),
std::begin(record_node->MutableOutputDefs()),
std::end(record_node->MutableOutputDefs()));
auto& input = recv_bw->MutableInputDefs()[0];
auto& new_output = CreateNodeArg(graph, input);
output_args.push_back(&new_output);
input = &new_output;
auto& new_node = graph.AddNode(graph.GenerateNodeName("WaitEvent"),
"WaitEvent",
"Backward pass",
input_args,
output_args, /* output */
{}, /* attribute */
kMSDomain);
wait_node = &new_node;
ORT_UNUSED_PARAMETER(wait_node);
}
return Status::OK();
}
Status TransformGraphForPipeline(Graph& graph) {
// insert WaitEvent and RecordEvent to the partition
Node* send_fw{nullptr};
Node* send_bw{nullptr};
Node* recv_fw{nullptr};
Node* recv_bw{nullptr};
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Send") {
if (IsBackward(node)) {
send_bw = &node;
} else {
send_fw = &node;
}
} else if (node.OpType() == "Recv") {
if (IsBackward(node)) {
recv_bw = &node;
} else {
recv_fw = &node;
}
}
}
std::vector<std::string> new_input_names;
std::vector<std::string> new_output_names;
ORT_RETURN_IF_ERROR(AddRecordBackward(graph, send_bw, new_input_names, new_output_names));
ORT_RETURN_IF_ERROR(AddWaitForward(graph, recv_fw, new_input_names));
ORT_RETURN_IF_ERROR(AddOrSkipRecordForwardWaitBackward(graph, send_fw, recv_bw, new_input_names));
auto fill_node_args = [&](const Graph& graph,
const std::vector<const NodeArg*>& existed_node_args,
std::vector<std::string>& new_node_arg_names,
std::vector<const NodeArg*>& merged_node_args) {
merged_node_args.insert(merged_node_args.end(), existed_node_args.begin(), existed_node_args.end());
for (auto& name : new_node_arg_names) {
merged_node_args.push_back(graph.GetNodeArg(name));
}
};
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();
std::vector<const NodeArg*> inputs_args_sets;
inputs_args_sets.reserve(graph_inputs.size() + new_input_names.size());
fill_node_args(graph, graph_inputs, new_input_names, inputs_args_sets);
const std::vector<const NodeArg*>& graph_outputs = graph.GetOutputs();
std::vector<const NodeArg*> outputs_args_sets;
outputs_args_sets.reserve(graph_outputs.size() + new_output_names.size());
fill_node_args(graph, graph_outputs, new_output_names, outputs_args_sets);
graph.SetInputs(inputs_args_sets);
graph.SetOutputs(outputs_args_sets);
graph.SetGraphResolveNeeded();
graph.SetGraphProtoSyncNeeded();
return graph.Resolve();
}
} // namespace training
} // namespace onnxruntime

View file

@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/graph/graph.h"
namespace onnxruntime {
namespace training {
void GetPipelineSendOutput(const Graph& graph, std::string& loss_name);
common::Status TransformGraphForPipeline(Graph& graph);
} // namespace training
} // namespace onnxruntime

View file

@ -12,6 +12,7 @@
#include "core/optimizer/conv_add_fusion.h"
#include "core/optimizer/constant_folding.h"
#include "core/optimizer/unsqueeze_elimination.h"
#include "core/optimizer/expand_elimination.h"
#include "core/optimizer/cast_elimination.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/conv_activation_fusion.h"
@ -54,6 +55,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(T
rule_transformer->Register(make_unique<InsertMaxPoolOutput>());
rule_transformer->Register(make_unique<AdjustBatchNormOutputs>());
rule_transformer->Register(make_unique<UnsqueezeElimination>());
rule_transformer->Register(make_unique<ExpandElimination>());
rule_transformer->Register(make_unique<CastElimination>());
rule_transformer->Register(make_unique<InsertSoftmaxCrossEntropyLossOutput>());

View file

@ -15,6 +15,7 @@
#include "core/optimizer/rule_based_graph_transformer.h"
#include "orttraining/core/graph/mixed_precision_transformer.h"
#include "orttraining/core/graph/tensorboard_transformer.h"
#include "orttraining/core/graph/pipeline_transformer.h"
#include "orttraining/core/graph/gradient_builder_base.h"
//Gist Encoding
@ -128,15 +129,23 @@ Status TrainingSession::ConfigureForTraining(
is_mixed_precision_enabled_ = config.mixed_precision_config.has_value();
std::string loss_name{};
const optional<LossFunctionInfo> loss_function_info =
config.loss_function_config.has_value()
? config.loss_function_config.value().loss_function_info
: optional<LossFunctionInfo>{};
optional<std::string> loss_scale_input_name =
is_mixed_precision_enabled_ ? optional<std::string>{""} : optional<std::string>{};
ORT_RETURN_IF_ERROR(ConfigureLossFunction(
config.loss_name, loss_function_info,
loss_scale_input_name.has_value() ? &loss_scale_input_name.value() : nullptr, loss_name));
is_mixed_precision_enabled_ ? optional<std::string>{""} : optional<std::string>{};
if (config.use_pipeline) {
// if use pipeline, first check if model contains send op. If it does, set the
// send node's output as the start tensor to build gradient graph
GetPipelineSendOutput(model_->MainGraph(), loss_name);
}
if (loss_name.empty()) {
const optional<LossFunctionInfo> loss_function_info =
config.loss_function_config.has_value()
? config.loss_function_config.value().loss_function_info
: optional<LossFunctionInfo>{};
ORT_RETURN_IF_ERROR(ConfigureLossFunction(
config.loss_name, loss_function_info,
loss_scale_input_name.has_value() ? &loss_scale_input_name.value() : nullptr, loss_name));
}
ORT_ENFORCE(
!loss_scale_input_name.has_value() || !loss_scale_input_name.value().empty(),
"loss_scale_input_name should not be set to an empty string.");
@ -170,7 +179,6 @@ Status TrainingSession::ConfigureForTraining(
<< weight_names_stream.str();
}
// add gradient graph
ORT_RETURN_IF_ERROR(BuildGradientGraph(
weight_names_to_train, loss_name, config.set_gradients_as_graph_outputs));
@ -182,12 +190,30 @@ Status TrainingSession::ConfigureForTraining(
weight_names_to_train, mixed_precision_config.use_fp16_initializers, fp32_weight_name_to_fp16_node_arg));
}
if (config.use_pipeline) {
ORT_RETURN_IF_ERROR(InsertPipelineOps());
}
// All non-float tensors are not trainable. Remove those weights.
// TODO: this is a temp workaround for removing rank tensor before adding optimizer.
// Re-visit after we port logic for model splitting and hence know the rank tensor name.
for (auto it = weights_to_train_.begin(); it != weights_to_train_.end();) {
const auto* node_arg = model_->MainGraph().GetNodeArg(*it);
ORT_RETURN_IF_NOT(node_arg, "Failed to get NodeArg with name ", *it);
if (node_arg->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
it = weights_to_train_.erase(it);
}
else{
++it;
}
}
// add optimizer or gradient accumulation
if (config.optimizer_config.has_value()) {
OptimizerGraphConfig opt_graph_config{};
std::unordered_map<std::string, OptimizerNodeConfig> opt_node_configs{};
ORT_RETURN_IF_ERROR(SetupOptimizerParams(
weight_names_to_train, fp32_weight_name_to_fp16_node_arg,
weights_to_train_, fp32_weight_name_to_fp16_node_arg,
loss_scale_input_name, config, opt_graph_config, opt_node_configs));
TrainingConfigurationResult::OptimizerConfigurationResult optimizer_config_result{};
@ -198,7 +224,7 @@ Status TrainingSession::ConfigureForTraining(
config_result.opt_config_result = optimizer_config_result;
} else {
if (config.gradient_accumulation_steps > 1) {
ORT_RETURN_IF_ERROR(BuildAccumulationNode(weight_names_to_train));
ORT_RETURN_IF_ERROR(BuildAccumulationNode(weights_to_train_));
}
}
@ -439,6 +465,11 @@ Status TrainingSession::AddTensorboard(const std::string& summary_name,
return DoPostLoadProcessing(*model_);
}
Status TrainingSession::InsertPipelineOps() {
ORT_RETURN_IF_ERROR(TransformGraphForPipeline(model_->MainGraph()));
return DoPostLoadProcessing(*model_);
}
Status TrainingSession::ConfigureLossFunction(
const optional<std::string>& external_loss_name,
const optional<LossFunctionInfo>& loss_function_info,

View file

@ -132,6 +132,9 @@ class TrainingSession : public InferenceSession {
// The optimizer configuration.
// If not provided, no optimizer is added.
optional<OptimizerConfiguration> optimizer_config{};
// Whether to use pipeline in training.
bool use_pipeline{false};
};
/**
@ -262,6 +265,7 @@ class TrainingSession : public InferenceSession {
const std::vector<std::string>& norm_nodes,
const bool dump_convergence_metrics);
common::Status InsertPipelineOps();
common::Status ApplyTransformationsToMainGraph();
/** configure initial transformers for training */

View file

@ -169,6 +169,7 @@ class TrainingRunner {
common::Status ResetLossScaler();
size_t GetRound() const { return round_; }
TrainingSession& GetSession() { return session_; }
private:
Status TrainingLoop(IDataLoader& training_data_loader, IDataLoader* test_data_loader);

View file

@ -510,6 +510,7 @@ def create_ort_training_session_bind_parameters(model, device, world_rank=-1, wo
dtype_torch_to_numpy(torch_params[param].dtype), list(torch_tensor.size()),
torch_tensor.data_ptr())
device_index = get_device_index(device)
create_and_bind_grad_or_grad_accumulate_buffer(train_io_binding, torch_tensor, param, enable_grad_accumulation, device, device_index)
return session, train_io_binding, eval_io_binding, output_name, torch_params, output_types

View file

@ -128,7 +128,7 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::ComputeTheoreticalJacobianTransp
const size_t dy_size = y_infos[y_idx].shape.Size();
// Compute the theoretical Jacobians one row at a time by back propagating
// '1.0'for each element of 'dy', while holding all other elements of 'dy' at zero.
// '1.0' for each element of 'dy', while holding all other elements of 'dy' at zero.
for (int c = 0; c < dy_size; ++c) { // for each value in the dy input vector
// clear OpTester input/output/initializer
op_session.ClearData();

View file

@ -1599,6 +1599,57 @@ TEST(GradientCheckerTest, GatherNDGrad_int32_indice_unique_float_data_axis_2) {
EXPECT_IS_TINY(max_error);
}
TEST(GradientCheckerTest, GatherElementsGradWithDuplicateUpdate) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"GatherElements", kOnnxDomain, 11};
TensorInfo data_info({3, 3}, true);
TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 2, 0, 2, 0, 0}};
TensorInfo y_info({2, 3}, true);
int64_t axis = 0;
gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas,
{MakeAttribute("axis", axis)});
EXPECT_IS_TINY(max_error);
}
TEST(GradientCheckerTest, GatherElementsGradWithoutDuplicateUpdate) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"GatherElements", kOnnxDomain, 11};
TensorInfo data_info({3, 3}, true);
TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 1, 1, 2, 2, 2}};
TensorInfo y_info({2, 3}, true);
int64_t axis = 0;
gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas,
{MakeAttribute("axis", axis)});
EXPECT_IS_TINY(max_error);
}
TEST(GradientCheckerTest, GatherElementsGradAxisWithDuplicateUpdate) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"GatherElements", kOnnxDomain, 11};
TensorInfo data_info({3, 3}, true);
TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 1, 1, 1, 1, 1}};
TensorInfo y_info({2, 3}, true);
int64_t axis = 1;
gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas,
{MakeAttribute("axis", axis)});
EXPECT_IS_TINY(max_error);
}
TEST(GradientCheckerTest, LayerNormGrad) {
GradientChecker<float, float, float> gradient_checker;
{
@ -1800,6 +1851,84 @@ TEST(Synchronization, WaitAndRecordEventMany) {
}
}
TEST(GradientCheckerTest, ExpandGrad) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"Expand"};
//input_shape = (2, 3, 1), target_shape = (2, 3, 4) ==> shape(result) = (2, 3, 4)
{
TensorInfo x_info({2, 3, 1}, true);
TensorInfo shape_info({3}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6}, {2, 3, 4}};
TensorInfo y_info({2, 3, 4}, true);
gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas);
EXPECT_IS_TINY(max_error);
}
//input_shape = (2, 3, 1), target_shape = (1, 1, 4) ==> shape(result) = (2, 3, 4)
{
TensorInfo x_info({2, 3, 1}, true);
TensorInfo shape_info({3}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6}, {1, 1, 4}};
TensorInfo y_info({2, 3, 4}, true);
gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas);
EXPECT_IS_TINY(max_error);
}
//input_shape = (2, 3, 1), target_shape = (4) ==> shape(result) = (2, 3, 4)
{
TensorInfo x_info({2, 3, 1}, true);
TensorInfo shape_info({1}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6}, {4}};
TensorInfo y_info({2, 3, 4}, true);
gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas);
EXPECT_IS_TINY(max_error);
}
//input_shape = (2, 3, 1), target_shape = (1, 1) ==> shape(result) = (2, 3, 1)
{
TensorInfo x_info({2, 3, 1}, true);
TensorInfo shape_info({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6}, {1, 1}};
TensorInfo y_info({2, 3, 1}, true);
gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas);
EXPECT_IS_TINY(max_error);
}
//input_shape = (2, 3), target_shape = (4, 5, 2, 3) ==> shape(result) = (4, 5, 2, 3)
{
TensorInfo x_info({2, 3}, true);
TensorInfo shape_info({4}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6}, {4, 5, 2, 3}};
TensorInfo y_info({4, 5, 2, 3}, true);
gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas);
EXPECT_IS_TINY(max_error);
}
//input_shape = (1, 2, 3), target_shape = (4, 5, 1, 1) ==> shape(result) = (4, 5, 2, 3)
{
TensorInfo x_info({1, 2, 3}, true);
TensorInfo shape_info({4}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6}, {4, 5, 1, 1}};
TensorInfo y_info({4, 5, 2, 3}, true);
gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas);
EXPECT_IS_TINY(max_error);
}
}
} // namespace test
} // namespace onnxruntime

View file

@ -997,6 +997,106 @@ class PipelineBatchPlanner {
}
};
// verify pipeline config can load and gradient graph can construct.
TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) {
PathString filename_base = ORT_TSTR("testdata/test_training_model_");
auto load_gradient_graph = [](int stageIdx, PathString& input_file, PathString& output_file) {
auto config = MakeBasicTrainingConfig();
config.use_pipeline = true;
PathString backprop_model_file;
ASSERT_STATUS_OK(BuildBackPropGraph(input_file, config, backprop_model_file));
std::shared_ptr<Model> model;
ASSERT_TRUE(Model::Load(backprop_model_file, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
Graph& graph = model->MainGraph();
auto is_backward = [](Node& node) {
return (node.Description() == "Backward pass");
};
// check for wait/record node
Node* wait_fw{nullptr};
Node* wait_bw{nullptr};
Node* record_fw{nullptr};
Node* record_bw{nullptr};
for (auto& node : graph.Nodes()) {
if (node.OpType() == "WaitEvent") {
if (is_backward(node)) {
wait_bw = &node;
} else {
wait_fw = &node;
}
} else if (node.OpType() == "RecordEvent") {
if (is_backward(node)) {
record_bw = &node;
} else {
record_fw = &node;
}
}
}
// every partition should have wait forward and record backward
ASSERT_TRUE(wait_fw && record_bw);
if (stageIdx == 2) {
// the last partition can perform back prop right away. It won't have record
// forward and wait backward
ASSERT_TRUE(!record_fw && !wait_bw);
} else {
ASSERT_TRUE(record_fw && wait_bw);
}
// check for send/recv node
Node* send_fw{nullptr};
Node* send_bw{nullptr};
Node* recv_fw{nullptr};
Node* recv_bw{nullptr};
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Send") {
if (is_backward(node)) {
send_bw = &node;
} else {
send_fw = &node;
}
} else if (node.OpType() == "Recv") {
if (is_backward(node)) {
recv_bw = &node;
} else {
recv_fw = &node;
}
}
}
// except the last partion, each partition should have send forward and recv backward
if (stageIdx == 0 || stageIdx == 1) {
ASSERT_TRUE(send_fw && recv_bw);
} else {
ASSERT_TRUE(!send_fw && !recv_bw);
}
// except the first partion, each partition should have recv forward and send backward
if (stageIdx == 1 || stageIdx == 2) {
ASSERT_TRUE(recv_fw && send_bw);
} else {
ASSERT_TRUE(!recv_fw && !send_bw);
}
auto mp = model->ToProto();
std::ofstream ofs(output_file, std::ofstream::binary);
mp.SerializeToOstream(&ofs);
ofs.close();
};
for (int i = 0; i < 3; ++i) {
#ifdef _WIN32
auto surfix = std::to_wstring(i);
#else
auto surfix = std::to_string(i);
#endif
PathString input_file = filename_base + surfix + ORT_TSTR(".onnx");
PathString output_file = filename_base + surfix + ORT_TSTR("_back.onnx");
load_gradient_graph(i, input_file, output_file);
}
}
TEST(GradientGraphBuilderTest, TrainingSession_WithPipeline) {
auto config = MakeBasicTrainingConfig();
//config.set_gradients_as_graph_outputs = true;

View file

@ -0,0 +1,227 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace cuda {
namespace test {
TEST(GatherElementsGrad, WithoutAxis) {
onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain);
test.AddInput<float>("dY", {2, 3},
{1.0f, 1.1f, 1.2f,
2.0f, 2.1f, 2.2f});
std::vector<int64_t> data_shape = {3, 3};
test.AddInput<int64_t>("data_shape", {2}, data_shape);
test.AddInput<int64_t>("indices", {2, 3},
{1, 0, 2,
0, 2, 1});
test.AddOutput<float>("dX", {3, 3},
{2.0f, 1.1f, 0.0f,
1.0f, 0.0f, 2.2f,
0.0f, 2.1f, 1.2f});
test.Run();
}
TEST(GatherElementsGrad, WithAxis) {
onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain);
test.AddAttribute<int64_t>("axis", 1);
test.AddInput<float>("dY", {1, 2}, {1.1f, 2.1f});
std::vector<int64_t> data_shape = {1, 5};
test.AddInput<int64_t>("data_shape", {2}, data_shape);
test.AddInput<int64_t>("indices", {1, 2}, {1, 3});
test.AddOutput<float>("dX", {1, 5}, {0.0f, 1.1f, 0.0f, 2.1f, 0.0f});
test.Run();
}
TEST(GatherElementsGrad, ThreeDimsWithAxis_0) {
onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain);
test.AddAttribute<int64_t>("axis", 0);
test.AddInput<float>("dY", {1, 3, 3},
{11.0f, 12.0f, 13.0f,
14.0f, 15.0f, 16.0f,
17.0f, 18.0f, 19.0f});
std::vector<int64_t> data_shape = {1, 3, 3};
test.AddInput<int64_t>("data_shape", {3}, data_shape);
// Because axis 0 is only 1 dimension it should be all zeros
test.AddInput<int64_t>("indices", {1, 3, 3},
{0, 0, 0,
0, 0, 0,
0, 0, 0});
test.AddOutput<float>("dX", {1, 3, 3},
{11.0f, 12.0f, 13.0f,
14.0f, 15.0f, 16.0f,
17.0f, 18.0f, 19.0f});
test.Run();
}
TEST(GatherElementsGrad, ThreeDimsWithAxis_2) {
onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain);
test.AddAttribute<int64_t>("axis", 2);
test.AddInput<float>("dY", {1, 3, 3},
{11, 12, 13,
14, 15, 16,
17, 18, 19});
std::vector<int64_t> data_shape = {1, 3, 3};
test.AddInput<int64_t>("data_shape", {3}, data_shape);
test.AddInput<int64_t>("indices", {1, 3, 3},
{2, 1, 0,
2, 1, 0,
2, 1, 0});
test.AddOutput<float>("dX", {1, 3, 3},
{13, 12, 11,
16, 15, 14,
19, 18, 17});
test.Run();
}
TEST(GatherElementsGrad, NegativeAxis) {
onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain);
test.AddAttribute<int64_t>("axis", -1);
test.AddInput<float>("dY", {1, 2}, {1.1f, 2.1f});
std::vector<int64_t> data_shape = {1, 5};
test.AddInput<int64_t>("data_shape", {2}, data_shape);
test.AddInput<int64_t>("indices", {1, 2}, {1, 3});
test.AddOutput<float>("dX", {1, 5}, {0.0f, 1.1f, 0.0f, 2.1f, 0.0f});
test.Run();
}
TEST(GatherElementsGrad, IndicesUpdatesDontMatch) {
onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain);
test.AddAttribute<int64_t>("axis", 1);
test.AddInput<float>("dY", {1, 2}, {1.1f, 2.1f});
std::vector<int64_t> data_shape = {1, 5};
test.AddInput<int64_t>("data_shape", {2}, data_shape);
test.AddInput<int64_t>("indices", {1, 3}, {1, 3, 3});
test.AddOutput<float>("dX", {1, 5}, {1.0f, 3.1f, 3.0f, 6.1f, 5.0f});
test.Run(onnxruntime::test::OpTester::ExpectResult::kExpectFailure, "Indices vs dY dimensions differs at position=1 3 vs 2");
}
TEST(GatherElementsGrad, ValidAxis) {
onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain);
test.AddAttribute<int64_t>("axis", 0);
test.AddInput<float>("dY", {1, 1, 1}, {5.0f});
std::vector<int64_t> data_shape = {4, 2, 1};
test.AddInput<int64_t>("data_shape", {3}, data_shape);
test.AddInput<int64_t>("indices", {1, 1, 1}, {3});
test.AddOutput<float>("dX", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 5.0f, 0.0f});
test.Run();
}
TEST(GatherElementsGrad, ValidNegativeIndex) {
onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain);
test.AddAttribute<int64_t>("axis", 0);
test.AddInput<float>("dY", {1, 1, 1}, {5.0f});
std::vector<int64_t> data_shape = {4, 2, 1};
test.AddInput<int64_t>("data_shape", {3}, data_shape);
test.AddInput<int64_t>("indices", {1, 1, 1}, {-1});
test.AddOutput<float>("dX", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 5.0f, 0.0f});
test.Run();
}
TEST(GatherElementsGrad, SameUpdateWithoutAxis) {
onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain);
test.AddInput<float>("dY", {2, 2},
{11.0f, 22.0f,
33.0f, 44.0f});
std::vector<int64_t> data_shape = {3, 3};
test.AddInput<int64_t>("data_shape", {2}, data_shape);
test.AddInput<int32_t>("indices", {2, 2},
{1, 1,
1, 1},
true);
test.AddOutput<float>("dX", {3, 3},
{0.0f, 0.0f, 0.0f,
44.0f, 66.0f, 0.0f,
0.0f, 0.0f, 0.0f});
test.Run();
}
TEST(GatherElementsGrad, SameUpdateWithAxis) {
onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain);
test.AddAttribute<int64_t>("axis", 1);
test.AddInput<float>("dY", {2, 3},
{11.0f, 22.0f, 33.0f,
44.0f, 55.0f, 66.0f});
std::vector<int64_t> data_shape = {3, 3};
test.AddInput<int64_t>("data_shape", {2}, data_shape);
test.AddInput<int32_t>("indices", {2, 3},
{1, 1, 1,
1, 1, 1},
true);
test.AddOutput<float>("dX", {3, 3},
{0.0f, 66.0f, 0.0f,
0.0f, 165.0f, 0.0f,
0.0f, 0.0f, 0.0f});
test.Run();
}
TEST(GatherElementsGrad, SameUpdateWithNegativeAxis) {
onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain);
test.AddAttribute<int64_t>("axis", -1);
test.AddInput<float>("dY", {2, 3},
{11.0f, 22.0f, 33.0f,
44.0f, 55.0f, 66.0f});
std::vector<int64_t> data_shape = {3, 3};
test.AddInput<int64_t>("data_shape", {2}, data_shape);
test.AddInput<int32_t>("indices", {2, 3},
{1, 0, 1,
1, 0, 1},
true);
test.AddOutput<float>("dX", {3, 3},
{22.0f, 44.0f, 0.0f,
55.0f, 110.0f, 0.0f,
0.0f, 0.0f, 0.0f});
test.Run();
}
TEST(GatherElementsGrad, SameUpdateWithoutAxisMLFloat16) {
onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain);
std::vector<float> update = {11.0f, 22.0f,
33.0f, 44.0f};
std::vector<MLFloat16> fp16_update(update.size());
onnxruntime::test::ConvertFloatToMLFloat16(update.data(), fp16_update.data(), static_cast<int>(update.size()));
std::vector<float> output = {0.0f, 0.0f, 0.0f,
44.0f, 66.0f, 0.0f,
0.0f, 0.0f, 0.0f};
std::vector<MLFloat16> fp16_output(output.size());
onnxruntime::test::ConvertFloatToMLFloat16(output.data(), fp16_output.data(), static_cast<int>(output.size()));
test.AddInput<MLFloat16>("dY", {2, 2}, fp16_update);
std::vector<int64_t> data_shape = {3, 3};
test.AddInput<int64_t>("data_shape", {2}, data_shape);
test.AddInput<int32_t>("indices", {2, 2},
{1, 1,
1, 1},
true);
test.AddOutput<MLFloat16>("dX", {3, 3}, fp16_output);
test.Run();
}
} // namespace test
} // namespace cuda
} // namespace onnxruntime

View file

@ -6,7 +6,6 @@
#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cudnn_common.h"
namespace onnxruntime {
namespace cuda {

View file

@ -6,7 +6,7 @@
#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cudnn_common.h"
namespace onnxruntime {
namespace cuda {

View file

@ -4,7 +4,7 @@
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cuda/cudnn_common.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/multi_tensor/common.cuh"
constexpr int PARALLEL_LOADS = 4;

View file

@ -4,7 +4,7 @@
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cuda/cudnn_common.h"
#include "core/providers/cuda/cuda_common.h"
namespace onnxruntime {
namespace cuda {

View file

@ -1,7 +1,6 @@
#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cudnn_common.h"
namespace onnxruntime {
namespace cuda {

View file

@ -190,8 +190,8 @@ __device__ void cuWelfordMuSigma2(
for (; l + 7 < n2; l += 8 * numx) {
for (int k = 0; k < 8; k += 2) {
float2 curr = __half22float2(*((__half2*)(lvals + l + k)));
cuWelfordOnlineSum(curr.x, mu, sigma2, count);
cuWelfordOnlineSum(curr.y, mu, sigma2, count);
cuWelfordOnlineSum(static_cast<float>(curr.x), mu, sigma2, count);
cuWelfordOnlineSum(static_cast<float>(curr.y), mu, sigma2, count);
}
}
for (; l < n2; ++l) {
@ -308,7 +308,7 @@ __global__ void cuApplyLayerNorm(
// 1) blockDim.x == GPU_WARP_SIZE
// 2) Tensors are contiguous
//
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
U mu, sigma2;
@ -576,7 +576,7 @@ __global__ void cuComputeGradInput(
const U* __restrict__ invvar,
const T* gamma,
T* grad_input) {
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];

View file

@ -4,7 +4,6 @@
#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cudnn_common.h"
namespace onnxruntime {
namespace cuda {

View file

@ -4,7 +4,6 @@
#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cudnn_common.h"
namespace onnxruntime {
namespace cuda {

View file

@ -4,7 +4,6 @@
#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cudnn_common.h"
#include "core/providers/cuda/multi_tensor/common.cuh"
namespace onnxruntime {

View file

@ -4,7 +4,6 @@
#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cudnn_common.h"
namespace onnxruntime {
namespace cuda {

View file

@ -0,0 +1,137 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/training_ops/cuda/tensor/gather_elements_grad.h"
#include "orttraining/training_ops/cuda/tensor/gather_elements_grad_impl.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/common.h"
namespace onnxruntime {
namespace cuda {
ONNX_OPERATOR_KERNEL_EX(
GatherElementsGrad,
kMSDomain,
1,
kCudaExecutionProvider,
KernelDefBuilder()
.InputMemoryType<OrtMemTypeCPUInput>(1) // 'GatherElements' data shape needs to be on CPU
.TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes())
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>())
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
GatherElementsGrad);
template <typename T>
struct GatherElementsGrad::ComputeImpl {
Status operator()(const Tensor* dY,
const Tensor* indices_tensor,
Tensor* dX,
const int rank,
TArray<int64_t>& buffer_output_dims,
TArray<int64_t>& buffer_input_strides,
const int64_t indices_size,
TArray<int64_t>& buffer_indices_dims,
TArray<fast_divmod>& fdm_indices_strides,
const int axis) const {
T* output_data = dX->template MutableData<T>();
const T* update_data = dY->template Data<T>();
typedef typename ToCudaType<T>::MappedType CudaT;
MLDataType Tin_type = indices_tensor->DataType();
if (utils::IsPrimitiveDataType<int32_t>(Tin_type)) {
const int32_t* indices_data = indices_tensor->template Data<int32_t>();
return GatherElementsGradImpl(
rank,
buffer_output_dims,
buffer_input_strides,
indices_data,
indices_size,
buffer_indices_dims,
fdm_indices_strides,
reinterpret_cast<const CudaT*>(update_data),
axis,
reinterpret_cast<CudaT*>(output_data));
} else if (utils::IsPrimitiveDataType<int64_t>(Tin_type)) {
const int64_t* indices_data = indices_tensor->template Data<int64_t>();
return GatherElementsGradImpl(
rank,
buffer_output_dims,
buffer_input_strides,
indices_data,
indices_size,
buffer_indices_dims,
fdm_indices_strides,
reinterpret_cast<const CudaT*>(update_data),
axis,
reinterpret_cast<CudaT*>(output_data));
}
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for Tin is not supported yet in ScatterElements.");
}
};
Status GatherElementsGrad::ComputeInternal(OpKernelContext* context) const {
const auto* dY = context->Input<Tensor>(0);
const Tensor* shape = context->Input<Tensor>(1);
const TensorShape data_shape(shape->template Data<int64_t>(), shape->Shape().Size());
const int axis = static_cast<int>(HandleNegativeAxis(axis_, data_shape.NumDimensions()));
const auto* indices_tensor = context->Input<Tensor>(2);
const auto& indices_dims = indices_tensor->Shape().GetDims();
const int64_t indices_size = indices_tensor->Shape().Size();
const auto& dY_dims = dY->Shape().GetDims();
if (indices_dims.size() != dY_dims.size()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Indices and dY must have the same rank");
}
for (size_t i = 0; i < indices_dims.size(); ++i) {
if (indices_dims[i] != dY_dims[i]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices vs dY dimensions differs at position=", i,
" ", indices_dims[i], " vs ", dY_dims[i]);
}
}
// According to the spec the rank of ind/upd shall be the same as output(data)
// and we also want to make sure that the dimensions of the of the ind/upd do not
// exceed that of the output
const auto& output_dims = data_shape.GetDims();
if (output_dims.size() != indices_dims.size()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices must have the same rank as Output. Indices rank=",
indices_dims.size(), ". Output rank=", output_dims.size());
}
for (size_t i = 0; i < output_dims.size(); ++i) {
if (output_dims[i] < indices_dims[i]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices dim=", indices_dims[i], " at pos=", i,
" is greater than Output dim=", output_dims[i]);
}
}
int rank = static_cast<int>(output_dims.size());
Tensor* dX = context->Output(0, data_shape);
CUDA_RETURN_IF_ERROR(cudaMemset(dX->MutableDataRaw(), 0, dX->SizeInBytes()));
TArray<int64_t> buffer_output_dims(output_dims);
TensorPitches input_strides(output_dims);
TArray<int64_t> buffer_input_strides(input_strides);
TArray<int64_t> buffer_indices_dims(indices_dims);
TArray<fast_divmod> fdm_indices_strides(rank);
TensorPitches indices_strides(indices_dims);
for (auto i = 0; i < rank; i++) {
fdm_indices_strides[i] = fast_divmod(static_cast<int>(indices_strides[i]));
}
utils::MLTypeCallDispatcherRet<Status, ComputeImpl, MLFloat16, float, double>
t_disp(dY->GetElementType());
return t_disp.Invoke(dY, indices_tensor, dX, rank,
buffer_output_dims, buffer_input_strides, indices_size,
buffer_indices_dims, fdm_indices_strides, axis);
}
} // namespace cuda
} // namespace onnxruntime

View file

@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cuda/cuda_common.h"
namespace onnxruntime {
namespace cuda {
class GatherElementsGrad final : public CudaKernel {
public:
GatherElementsGrad(const OpKernelInfo& info) : CudaKernel(info) {
info.GetAttrOrDefault("axis", &axis_, static_cast<int64_t>(0));
}
~GatherElementsGrad() = default;
Status ComputeInternal(OpKernelContext* context) const override;
private:
template <typename T>
struct ComputeImpl;
int64_t axis_;
};
} // namespace cuda
} // namespace onnxruntime

View file

@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <stdint.h>
#include "core/providers/cuda/shared_inc/cuda_utils.h"
namespace onnxruntime {
namespace cuda {
template <typename T, typename Tin>
Status GatherElementsGradImpl(
const int rank,
TArray<int64_t>& buffer_input_dims,
TArray<int64_t>& buffer_input_strides,
const Tin* indices_data,
const int64_t indices_size,
TArray<int64_t>& buffer_indices_dims,
TArray<fast_divmod>& indices_strides,
const T* updates,
const int axis,
T* output_data);
} // namespace cuda
} // namespace onnxruntime

View file

@ -121,6 +121,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_float, LayerNormalizationGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float, LayerNormalizationGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SliceGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherElementsGrad);
#ifdef USE_HOROVOD
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, HorovodAllReduce);
@ -252,6 +253,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_float, LayerNormalizationGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float, LayerNormalizationGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SliceGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherElementsGrad)>,
#ifdef USE_HOROVOD
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, HorovodAllReduce)>,

View file

@ -17,6 +17,7 @@ import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import os
# TODO: remove after ready for CV
# import sys
@ -27,13 +28,15 @@ import numpy as np
# from ort_trainer import IODescription, ModelDescription, ORTTrainer, ORTModel
from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer, ORTModel
from mpi4py import MPI
from onnxruntime.capi._pybind_state import set_cuda_device_id
class NeuralNet(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(NeuralNet, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, num_classes)
self.fc2 = nn.Linear(hidden_size, num_classes)
def forward(self, x):
out = self.fc1(x)
@ -79,18 +82,19 @@ def test_with_model(args, model, device, test_loader, optimizer, epoch):
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
def train_with_trainer(args, trainer, device, train_loader, epoch):
def train_with_trainer(args, trainer, device, train_loader, epoch):
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
data = data.reshape(data.shape[0], -1)
learning_rate = torch.tensor([args.lr])
loss = trainer.train_step((data, target, learning_rate))
loss = trainer.train_step(data, target, learning_rate)
# Since the output corresponds to [loss_desc, probability_desc], the first value is taken as loss.
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
100. * batch_idx / len(train_loader), loss[0]))
# TODO: comple this once ORT training can do evaluation.
def test_with_trainer(args, trainer, device, test_loader):
@ -152,8 +156,6 @@ def main():
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 0, 'pin_memory': True}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
@ -167,6 +169,19 @@ def main():
transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])),
batch_size=args.test_batch_size, shuffle=True, **kwargs)
comm = MPI.COMM_WORLD
args.local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) if ('OMPI_COMM_WORLD_LOCAL_RANK' in os.environ) else 0
args.world_rank = int(os.environ['OMPI_COMM_WORLD_RANK']) if ('OMPI_COMM_WORLD_RANK' in os.environ) else 0
args.world_size=comm.Get_size()
torch.cuda.set_device(args.local_rank)
if use_cuda:
device = torch.device("cuda", args.local_rank)
else:
device = torch.device("cpu")
args.n_gpu = 1
set_cuda_device_id(args.local_rank)
input_size = 784
hidden_size = 500
num_classes = 10
@ -175,14 +190,17 @@ def main():
model_desc = mnist_model_description()
if args.use_ort_trainer:
# use log_interval as gradient accumulate steps
trainer = ORTTrainer(model, my_loss, model_desc, "SGDOptimizer", None, IODescription('Learning_Rate', [1,], torch.float32), device)
trainer = ORTTrainer(model, my_loss, model_desc, "LambOptimizer", None, IODescription('Learning_Rate', [1,], torch.float32), device, 1, None,
args.world_rank, args.world_size, use_mixed_precision=False, allreduce_post_accumulation = True)
print('\nBuild ort model done.')
for epoch in range(1, args.epochs + 1):
train_with_trainer(args, trainer, device, train_loader, epoch)
import pdb
test_with_trainer(args, trainer, device, test_loader)
else:
model = ORTModel(model, my_loss, model_desc, device)
model = ORTModel(model, my_loss, model_desc, device, None, args.world_rank, args.world_size)
print('\nBuild ort model done.')
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

View file

@ -4,7 +4,6 @@ import onnx
from onnx import helper
from onnx import TensorProto
from onnx import OperatorSetIdProto
# Edge that needs to be cut for the split.
# If the edge is feeding into more than one nodes, and not all the nodes belong to the same cut,
# specify those consuming nodes that need to be cut
@ -29,31 +28,6 @@ def split_graph(model, split_edge_groups):
new_send_nodes = []
new_recv_nodes = []
# Add wait for initial inputs. This needs to be done first before new inputs
# are introduced from split
initializer_lists = [a.name for a in model.graph.initializer]
input_tensors = [
value.name for value in model.graph.input if value.name not in initializer_lists]
input_wait_signal = model.graph.input.add()
input_wait_signal.CopyFrom(helper.make_tensor_value_info(
'input_wait_signal', onnx.TensorProto.INT64, None))
input_wait = model.graph.node.add()
input_wait.CopyFrom(helper.make_node(
'WaitEvent',
inputs=['input_wait_signal'],
outputs=[],
domain=ms_domain))
for i in input_tensors:
for node in model.graph.node:
for j in range(len(node.input)):
if node.input[j] == i:
node.input[j] = i + '_sync'
input_wait.input.extend(input_tensors)
input_wait.output.extend([i + '_sync' for i in input_tensors])
for cut_index in range(len(split_edge_groups)):
edgeIds = split_edge_groups[cut_index]
@ -62,7 +36,7 @@ def split_graph(model, split_edge_groups):
upstream_nodes = []
upstream_nodes_output_index = []
output_shapes = []
element_types = []
for id in edgeIds:
for node in model.graph.node:
if len(node.output) >= 1:
@ -70,25 +44,43 @@ def split_graph(model, split_edge_groups):
if j == id:
upstream_nodes.append(node)
upstream_nodes_output_index.append(i)
for info in model.graph.value_info:
if info.name == id:
output_shapes.append(info.type)
record_signal = model.graph.input.add()
record_signal.CopyFrom(helper.make_tensor_value_info(
'record_input_signal' + str(cut_index), onnx.TensorProto.INT64, None))
wait_signal = model.graph.input.add()
wait_signal.CopyFrom(helper.make_tensor_value_info(
'wait_input_signal' + str(cut_index), onnx.TensorProto.INT64, None))
# assuming all tensors are of type float
element_types.append(1)
for info in model.graph.value_info:
if info.name == id:
output_shapes.append(info.type)
send_input_signal_name = 'send_input_signal' + str(cut_index)
send_signal = model.graph.input.add()
send_signal.CopyFrom(helper.make_tensor_value_info(
'send_input_signal' + str(cut_index), onnx.TensorProto.BOOL, None))
send_input_signal_name, onnx.TensorProto.BOOL, None))
send_signal = helper.make_tensor(
send_input_signal_name, TensorProto.BOOL, (), (True,))
model.graph.initializer.extend([send_signal])
recv_input_signal_name = 'recv_input_signal' + str(cut_index)
recv_signal = model.graph.input.add()
recv_signal.CopyFrom(helper.make_tensor_value_info(
'recv_input_signal' + str(cut_index), onnx.TensorProto.BOOL, None))
recv_input_signal_name, onnx.TensorProto.BOOL, None))
recv_signal = helper.make_tensor(
recv_input_signal_name, TensorProto.BOOL, (), (True,))
model.graph.initializer.extend([recv_signal])
send_dst_rank_name = 'send_dst_rank' + str(cut_index)
send_dst_rank = model.graph.input.add()
send_dst_rank.CopyFrom(helper.make_tensor_value_info(
send_dst_rank_name, onnx.TensorProto.INT64, None))
send_dst_rank = helper.make_tensor(
send_dst_rank_name, TensorProto.INT64, (), (cut_index + 1,))
model.graph.initializer.extend([send_dst_rank])
recv_src_rank_name = 'recv_src_rank' + str(cut_index)
recv_src_rank = model.graph.input.add()
recv_src_rank.CopyFrom(helper.make_tensor_value_info(
recv_src_rank_name, onnx.TensorProto.INT64, None))
recv_src_rank = helper.make_tensor(
recv_src_rank_name, TensorProto.INT64, (), (cut_index,))
model.graph.initializer.extend([recv_src_rank])
# output signal from send after cut
send_output_signal = model.graph.output.add()
@ -103,41 +95,23 @@ def split_graph(model, split_edge_groups):
new_send = model.graph.node.add()
new_send.CopyFrom(helper.make_node(
'Send',
inputs=['send_input_signal' + str(cut_index)],
inputs=[send_input_signal_name, send_dst_rank_name],
outputs=['send_output_signal' + str(cut_index)],
tag=0,
src=cut_index,
dst=cut_index + 1,
domain=ms_domain,
element_type=7, # assuming all tensors are of type float
element_types=element_types,
name='send'))
new_receive = model.graph.node.add()
new_receive.CopyFrom(helper.make_node(
'Recv',
inputs=['recv_input_signal' + str(cut_index)],
inputs=[recv_input_signal_name, recv_src_rank_name],
outputs=['receive_output_signal' + str(cut_index)],
tag=1,
src=cut_index,
dst=cut_index + 1,
tag=0,
domain=ms_domain,
element_type=7, # assuming all tensors are of type float
element_types=element_types,
name='receive'))
new_wait = model.graph.node.add()
new_wait.CopyFrom(helper.make_node(
'WaitEvent',
inputs=['wait_input_signal' + str(cut_index)],
outputs=[],
domain=ms_domain))
new_record = model.graph.node.add()
new_record.CopyFrom(helper.make_node(
'RecordEvent',
inputs=['record_input_signal' + str(cut_index)],
outputs=[],
domain=ms_domain))
for i in range(len(upstream_nodes)):
n = upstream_nodes[i]
idx = upstream_nodes_output_index[i]
@ -155,24 +129,16 @@ def split_graph(model, split_edge_groups):
'_recv' + str(cut_index)
add_expand_type(model, new_receive_output_name, output_type)
new_wait_output_name = output_edge_name + '_wait' + str(cut_index)
add_expand_type(model, new_wait_output_name, output_type)
# the order of data flow is: node-output -> record -> send -> recv -> wait -> node-input
new_record.input.extend([output_edge_name])
new_record.output.extend([new_send_input_name])
new_send.input.extend([new_send_input_name])
new_send.input.extend([output_edge_name])
new_receive.output.extend([new_receive_output_name])
new_wait.input.extend([new_receive_output_name])
new_wait.output.extend([new_wait_output_name])
for output_node in output_nodes:
for i in range(len(output_node.input)):
for edgeId in edgeIds:
if output_node.input[i] == edgeId:
output_node.input[i] = new_wait_output_name
output_node.input[i] = new_receive_output_name
new_send_nodes.append(new_send)
new_recv_nodes.append(new_receive)
@ -236,9 +202,50 @@ def add_identity(model, cuttingEdge, newEdgeIdName):
if output_nodes[i].input[j] == edgeId:
output_nodes[i].input[j] = newEdgeIdName
return newEdgeIdName
return new_identity
def insert_identity(model, all_cut_inputs):
count = 0
updated_edges = {}
new_added_identity = []
split_edge_groups = []
need_shape_inference = False
# Sweep the cut edge to see if there are edges feeding into nodes from two sub-graphs. If so,
# insert identity node after those edges with a new ID to distinguish the rest.
for cut_input in all_cut_inputs:
split_edges = []
for i in cut_input:
if i.consumingNodes:
# if this edge has previously been modified, update its edgeId before inserting new identity
if i.edgeId in updated_edges:
i.edgeId = updated_edges[i.edgeId]
new_edge_name = 'identity_output_' + str(count)
new_added_identity.append(
add_identity(model, i, new_edge_name))
count += 1
split_edges.append(new_edge_name)
updated_edges[i.edgeId] = new_edge_name
need_shape_inference = True
else:
split_edges.append(i.edgeId)
split_edge_groups.append(split_edges)
return split_edge_groups, new_added_identity, need_shape_inference
# after the graph is split, remove the added identity node because identity op is not registered in gradient builder.
def remove_identity(model, new_added_identity):
for node in new_added_identity:
assert node.op_type == 'Identity'
output_nodes = [
n for n in model.graph.node if node.output[0] in n.input]
for output_node in output_nodes:
for i in range(len(output_node.input)):
if output_node.input[i] == node.output[0]:
output_node.input[i] = node.input[0]
def find_all_connected_nodes(model, node):
nodes0, inputs = find_all_input_nodes(model, node)
nodes1, outputs = find_all_output_nodes(model, node)
@ -251,15 +258,36 @@ def get_index(node_list, node):
found = [i for i, n in enumerate(node_list) if n == node]
return found[0] if found else None
def get_identity_index_for_deleting(node_list, node):
for i, n in enumerate(node_list):
# The node's input name has been changed during send/recv insertion,
# but it is sufficient to just compare the type and outputs.
if (n.op_type == 'Identity' and n.output == node.output):
return i
return None
# traverse the graph, group connected nodes and generate subgraph
def generate_subgraph(model, start_nodes):
def generate_subgraph(model, start_nodes, identity_node_list):
subgraphs = []
main_graph = onnx.ModelProto()
main_graph.CopyFrom(model)
# remove added identity node before copy to subgraph
identity_node_index = []
for n in identity_node_list:
identity_node_index.append(get_identity_index_for_deleting(main_graph.graph.node, n))
identity_node_index.sort(reverse=True)
for i in reversed(range(len(main_graph.graph.node))):
try:
if i in identity_node_index:
del main_graph.graph.node[i]
except:
print("error deleting identity node", i)
all_visited_nodes = []
model_count = len(start_nodes)
for start in reversed(start_nodes):
@ -362,29 +390,8 @@ def main():
output_model_names = [os.path.splitext(input_model_name)[0] + '_' +
str(i) + '.onnx' for i in range(stage_count)]
split_edge_groups = []
count = 0
updated_edges = {}
need_shape_inference = False
# Sweep the cut edge to see if there are edges feeding into nodes from two sub-graphs. If so,
# insert identity node after those edges with a new ID to distinguish the rest.
for cut_input in all_cut_inputs:
split_edges = []
for i in cut_input:
if i.consumingNodes:
# if this edge has previously been modified, update its edgeId before inserting new identity
if i.edgeId in updated_edges:
i.edgeId = updated_edges[i.edgeId]
split_edge_groups, new_identity, need_shape_inference = insert_identity(model, all_cut_inputs)
new_edge_name = 'identity_output_' + str(count)
add_identity(model, i, new_edge_name)
count += 1
split_edges.append(new_edge_name)
updated_edges[i.edgeId] = new_edge_name
need_shape_inference = True
else:
split_edges.append(i.edgeId)
split_edge_groups.append(split_edges)
# new edge is being added, need to re-inference shape
if need_shape_inference:
@ -392,11 +399,13 @@ def main():
# after all need-to-be-cut edges identified, split the graph
new_sends, new_receives = split_graph(model, split_edge_groups)
sub_graphs = generate_subgraph(model, new_receives)
remove_identity(model, new_identity)
sub_graphs = generate_subgraph(model, new_receives, new_identity)
for i in range(stage_count):
sub_graphs[i] = onnx.shape_inference.infer_shapes(sub_graphs[i])
onnx.save(sub_graphs[i], output_model_names[i])
print("save to file: ", output_model_names[i])
if __name__ == "__main__":