mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
Add opset 15 kernels for Pow, BatchNorm, and Shape (#8442)
This commit is contained in:
parent
33a97e995b
commit
cee79526fd
25 changed files with 471 additions and 276 deletions
|
|
@ -38,7 +38,8 @@ Do not modify directly.*
|
|||
|AveragePool|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(float)|
|
||||
|||10|**T** = tensor(float)|
|
||||
|||[7, 9]|**T** = tensor(float)|
|
||||
|BatchNormalization|*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* input_mean:**U**<br> *in* input_var:**U**<br> *out* Y:**T**<br> *out* running_mean:**U**<br> *out* running_var:**U**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* mean:**T**<br> *in* var:**T**<br> *out* Y:**T**<br> *out* mean:**T**<br> *out* var:**T**<br> *out* saved_mean:**T**<br> *out* saved_var:**T**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T1**<br> *in* B:**T1**<br> *in* input_mean:**T2**<br> *in* input_var:**T2**<br> *out* Y:**T**<br> *out* running_mean:**T2**<br> *out* running_var:**T2**|14+|**T** = tensor(double), tensor(float)|
|
||||
|BatchNormalization|*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* input_mean:**U**<br> *in* input_var:**U**<br> *out* Y:**T**<br> *out* running_mean:**U**<br> *out* running_var:**U**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* mean:**T**<br> *in* var:**T**<br> *out* Y:**T**<br> *out* mean:**T**<br> *out* var:**T**<br> *out* saved_mean:**T**<br> *out* saved_var:**T**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T1**<br> *in* B:**T1**<br> *in* input_mean:**T2**<br> *in* input_var:**T2**<br> *out* Y:**T**<br> *out* running_mean:**T2**<br> *out* running_var:**T2**|15+|**T** = tensor(double), tensor(float)<br/> **T1** = tensor(double), tensor(float)<br/> **T2** = tensor(double), tensor(float)|
|
||||
|||14|**T** = tensor(double), tensor(float)<br/> **U** = tensor(double), tensor(float)|
|
||||
|||[9, 13]|**T** = tensor(double), tensor(float)|
|
||||
|||[7, 8]|**T** = tensor(double), tensor(float)|
|
||||
|BitShift|*in* X:**T**<br> *in* Y:**T**<br> *out* Z:**T**|11+|**T** = tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|
|
@ -202,7 +203,8 @@ Do not modify directly.*
|
|||
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[2, 10]|**T** = tensor(double), tensor(float)|
|
||||
|ParametricSoftplus|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|
||||
|Pow|*in* X:**T**<br> *in* Y:**T**<br> *out* Z:**T**<br><br>or<br><br>*in* X:**T**<br> *in* Y:**T1**<br> *out* Z:**T**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|
||||
|Pow|*in* X:**T**<br> *in* Y:**T**<br> *out* Z:**T**<br><br>or<br><br>*in* X:**T**<br> *in* Y:**T1**<br> *out* Z:**T**|15+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|
||||
|||[13, 14]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|
||||
|||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|
||||
|||[7, 11]|**T** = tensor(double), tensor(float)|
|
||||
|QLinearConv|*in* x:**T1**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T1**<br> *in* w:**T2**<br> *in* w_scale:**tensor(float)**<br> *in* w_zero_point:**T2**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T3**<br> *in* B:**T4**<br> *out* y:**T3**|10+|**T1** = tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(uint8)<br/> **T4** = tensor(int32)|
|
||||
|
|
@ -280,7 +282,8 @@ Do not modify directly.*
|
|||
|SequenceErase|*in* input_sequence:**S**<br> *in* position:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|
||||
|SequenceInsert|*in* input_sequence:**S**<br> *in* tensor:**T**<br> *in* position:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|
||||
|SequenceLength|*in* input_sequence:**S**<br> *out* length:**I**|11+|**I** = tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|
||||
|Shape|*in* data:**T**<br> *out* shape:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|Shape|*in* data:**T**<br> *out* shape:**T1**|15+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|||[13, 14]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|Shrink|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Sigmoid|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
|
||||
|
|
@ -446,7 +449,8 @@ Do not modify directly.*
|
|||
|AveragePool|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|||10|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|||[7, 9]|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|BatchNormalization|*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* input_mean:**U**<br> *in* input_var:**U**<br> *out* Y:**T**<br> *out* running_mean:**U**<br> *out* running_var:**U**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* mean:**T**<br> *in* var:**T**<br> *out* Y:**T**<br> *out* mean:**T**<br> *out* var:**T**<br> *out* saved_mean:**T**<br> *out* saved_var:**T**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T1**<br> *in* B:**T1**<br> *in* input_mean:**T2**<br> *in* input_var:**T2**<br> *out* Y:**T**<br> *out* running_mean:**T2**<br> *out* running_var:**T2**|14+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|BatchNormalization|*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* input_mean:**U**<br> *in* input_var:**U**<br> *out* Y:**T**<br> *out* running_mean:**U**<br> *out* running_var:**U**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* mean:**T**<br> *in* var:**T**<br> *out* Y:**T**<br> *out* mean:**T**<br> *out* var:**T**<br> *out* saved_mean:**T**<br> *out* saved_var:**T**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T1**<br> *in* B:**T1**<br> *in* input_mean:**T2**<br> *in* input_var:**T2**<br> *out* Y:**T**<br> *out* running_mean:**T2**<br> *out* running_var:**T2**|15+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(double), tensor(float), tensor(float16)|
|
||||
|||14|**T** = tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(double), tensor(float), tensor(float16)|
|
||||
|||[9, 13]|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|Cast|*in* input:**T1**<br> *out* output:**T2**|13+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|
|
@ -582,7 +586,8 @@ Do not modify directly.*
|
|||
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|||[2, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|ParametricSoftplus|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|Pow|*in* X:**T**<br> *in* Y:**T**<br> *out* Z:**T**<br><br>or<br><br>*in* X:**T**<br> *in* Y:**T1**<br> *out* Z:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|
||||
|Pow|*in* X:**T**<br> *in* Y:**T**<br> *out* Z:**T**<br><br>or<br><br>*in* X:**T**<br> *in* Y:**T1**<br> *out* Z:**T**|15+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|
||||
|||[13, 14]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|
||||
|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|
||||
|||[7, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|10+|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
|
||||
|
|
@ -653,7 +658,8 @@ Do not modify directly.*
|
|||
|SequenceErase|*in* input_sequence:**S**<br> *in* position:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|
||||
|SequenceInsert|*in* input_sequence:**S**<br> *in* tensor:**T**<br> *in* position:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|
||||
|SequenceLength|*in* input_sequence:**S**<br> *out* length:**I**|11+|**I** = tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|
||||
|Shape|*in* data:**T**<br> *out* shape:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|Shape|*in* data:**T**<br> *out* shape:**T1**|15+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|||[13, 14]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|Shrink|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Sigmoid|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|
||||
|
|
|
|||
|
|
@ -72,6 +72,15 @@ class TensorShape : private std::vector<int64_t> {
|
|||
memcpy(dims, data(), sizeof(value_type) * std::min(num_dims, NumDimensions()));
|
||||
}
|
||||
|
||||
/**
|
||||
Copy dims from a specific start dim into an array with given size
|
||||
`start_dim` is expected to be in the inclusive range [0, NumDimensions() - 1]
|
||||
and this function does no checks to ensure that
|
||||
*/
|
||||
void CopyDims(int64_t* dims, size_t start_dim, size_t num_dims) const {
|
||||
memcpy(dims, data() + start_dim, sizeof(value_type) * std::min(num_dims, NumDimensions() - start_dim));
|
||||
}
|
||||
|
||||
/**
|
||||
Return underlying vector representation.
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "core/optimizer/constant_folding.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
|
|
@ -25,6 +27,20 @@ ConstantFolding::ConstantFolding(const IExecutionProvider& execution_provider,
|
|||
// We need to handle a Shape node separately as the input doesn't need to be a constant initializer for
|
||||
// Shape to be able to be constant folded.
|
||||
static bool ConstantFoldShapeNode(Graph& graph, Node& node) {
|
||||
// Opset-15 Shape supports slicing using a 'start' and 'end' attribute
|
||||
const auto& shape_attributes = node.GetAttributes();
|
||||
|
||||
int64_t start = 0;
|
||||
int64_t end = std::numeric_limits<int64_t>::max();
|
||||
|
||||
for (const auto& attr : shape_attributes) {
|
||||
if (attr.first == "start") {
|
||||
start = attr.second.i();
|
||||
} else if (attr.first == "end") {
|
||||
end = attr.second.i();
|
||||
}
|
||||
}
|
||||
|
||||
auto shape = node.MutableInputDefs()[0]->Shape();
|
||||
bool is_concrete_shape = true;
|
||||
std::vector<int64_t> dim_values;
|
||||
|
|
@ -42,14 +58,30 @@ static bool ConstantFoldShapeNode(Graph& graph, Node& node) {
|
|||
}
|
||||
|
||||
if (is_concrete_shape) {
|
||||
int64_t rank = static_cast<int64_t>(dim_values.size());
|
||||
|
||||
// We ascertain the "true" starts/ends (if they were provided)
|
||||
// Opset-15 Shape op supports slicing shape values
|
||||
|
||||
// Deal with negatives and clamp
|
||||
start = start < 0 ? start + rank : start;
|
||||
start = start < 0 ? 0 : ((start > rank) ? rank : start);
|
||||
|
||||
end = end < 0 ? end + rank : end;
|
||||
end = end < 0 ? 0 : ((end > rank) ? rank : end);
|
||||
|
||||
int64_t slice_length = end - start;
|
||||
size_t clamped_slice_length = slice_length < 0 ? 0 : static_cast<size_t>(slice_length);
|
||||
|
||||
ONNX_NAMESPACE::TensorProto shape_constant;
|
||||
auto* constant_arg_out = node.MutableOutputDefs()[0];
|
||||
shape_constant.set_name(constant_arg_out->Name());
|
||||
shape_constant.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
shape_constant.add_dims(dim_values.size());
|
||||
shape_constant.set_raw_data(dim_values.data(), dim_values.size() * sizeof(int64_t));
|
||||
shape_constant.add_dims(clamped_slice_length);
|
||||
shape_constant.set_raw_data(dim_values.data() + start,
|
||||
clamped_slice_length * sizeof(int64_t));
|
||||
ONNX_NAMESPACE::TensorShapeProto result_shape;
|
||||
result_shape.add_dim()->set_dim_value(dim_values.size());
|
||||
result_shape.add_dim()->set_dim_value(clamped_slice_length);
|
||||
constant_arg_out->SetShape(result_shape);
|
||||
graph.AddInitializedTensor(shape_constant);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,7 +37,6 @@
|
|||
#include "core/optimizer/relu_clip_fusion.h"
|
||||
#include "core/optimizer/reshape_fusion.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
#include "core/optimizer/shape_to_initializer.h"
|
||||
#include "core/optimizer/skip_layer_norm_fusion.h"
|
||||
#include "core/optimizer/slice_elimination.h"
|
||||
#include "core/optimizer/unsqueeze_elimination.h"
|
||||
|
|
@ -75,7 +74,6 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
|
|||
rules.push_back(std::make_unique<FuseReluClip>());
|
||||
rules.push_back(std::make_unique<GemmTransposeFusion>());
|
||||
rules.push_back(std::make_unique<NotWhereFusion>());
|
||||
rules.push_back(std::make_unique<ShapeToInitializer>());
|
||||
rules.push_back(std::make_unique<ConvAddFusion>());
|
||||
rules.push_back(std::make_unique<ConvMulFusion>());
|
||||
rules.push_back(std::make_unique<ConvBNFusion>());
|
||||
|
|
|
|||
|
|
@ -1,80 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/optimizer/shape_to_initializer.h"
|
||||
#include "core/graph/graph.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/graph/op.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/optimizer_execution_frame.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Turns the statically inferred shape of the Shape node's input into an initializer named after
// the node's output, then tries to replace the node with that initializer.
// NOTE(review): the input shape pointer is dereferenced without a null check here; this relies on
// the node having passed SatisfyCondition first - confirm against the rule framework.
Status ShapeToInitializer::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
  const ONNX_NAMESPACE::TensorShapeProto* known_shape = node.InputDefs()[0]->Shape();
  const int rank = known_shape->dim_size();

  // Gather every dimension value of the input; these become the payload of the initializer.
  std::vector<int64_t> shape_values;
  for (const auto& dimension : known_shape->dim()) {
    shape_values.push_back(gsl::narrow_cast<int64_t>(dimension.dim_value()));
  }

  const auto* output_def = node.OutputDefs()[0];

  // The initializer takes over the name and element type of the Shape node's output.
  ONNX_NAMESPACE::TensorProto initializer;
  initializer.set_name(output_def->Name());
  initializer.set_data_type(output_def->TypeAsProto()->tensor_type().elem_type());

  // The initializer itself is one-dimensional, with one entry per input dimension.
  TensorShape initializer_shape({gsl::narrow_cast<int64_t>(rank)});
  for (const auto& extent : initializer_shape.GetDims()) {
    initializer.add_dims(extent);
  }

  // Here we expect little-endian format to set raw data of the TensorProto.
  initializer.set_raw_data(shape_values.data(), shape_values.size() * sizeof(int64_t));

  auto& initializer_arg = graph_utils::AddInitializer(graph, initializer);

  // Only report the node as removed when the replacement actually happened.
  if (graph_utils::ReplaceNodeWithInitializer(graph, node, initializer_arg)) {
    rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
  }

  return Status::OK();
}
|
||||
|
||||
// Decides whether a node can be replaced by an initializer holding its input's shape.
// Returns true only when all of the following hold:
//   - the node is a supported "Shape" op (versions 1 and 13 are listed),
//   - the input has a statically known shape,
//   - every dimension of that shape has a concrete, non-negative value
//     (a symbolic or negative dimension disqualifies the node),
//   - the graph allows replacing the node with an initializer named after the node's output.
bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Shape", {1, 13})) {
    return false;
  }

  const auto* known_shape = node.InputDefs()[0]->Shape();
  if (known_shape == nullptr) {
    return false;
  }

  for (const auto& dimension : known_shape->dim()) {
    const bool is_concrete = utils::HasDimValue(dimension) && dimension.dim_value() >= 0;
    if (!is_concrete) {
      return false;
    }
  }

  // The initializer that replaces the node reuses the name of the node's output.
  return graph_utils::CanReplaceNodeWithInitializer(graph, node, node.OutputDefs()[0]->Name(), logger);
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
@Class ShapeToInitializer
|
||||
|
||||
When the input to a Shape operator is statically known (through shape inference), this rule replaces the Shape node
|
||||
with an initializer to the downstream nodes.
|
||||
|
||||
This rule is only attempted on nodes with op type "Shape".
|
||||
*/
|
||||
class ShapeToInitializer : public RewriteRule {
|
||||
public:
|
||||
ShapeToInitializer() noexcept : RewriteRule("ShapeToInitializer") {}
|
||||
|
||||
std::vector<std::string> TargetOpTypes() const noexcept override {
|
||||
return {"Shape"};
|
||||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -527,7 +527,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, ArgMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, ArgMin);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Reshape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Shape);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 14, Shape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Concat);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Less);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Less);
|
||||
|
|
@ -590,7 +590,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Exp);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Log);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Log);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Pow);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 14, Pow);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, DepthToSpace);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Slice);
|
||||
|
|
@ -686,12 +686,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int64_t, Div);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, Reshape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, Identity);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, float, BatchNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, double, BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, GRU);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, LSTM);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, RNN);
|
||||
|
||||
// Opset 15
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 15, Pow);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 15, float, BatchNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 15, double, BatchNormalization);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 15, Shape);
|
||||
|
||||
// !!PLEASE READ BELOW!! Following that, add new entries above this comment
|
||||
|
||||
|
|
@ -1151,9 +1156,9 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int64_t,
|
||||
MatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, float,
|
||||
BatchNormalization)>,
|
||||
BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, double,
|
||||
BatchNormalization)>,
|
||||
BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, PRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9,
|
||||
float, Upsample)>,
|
||||
|
|
@ -1547,7 +1552,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
int32_t, ArgMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13,
|
||||
Reshape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Shape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 14, Shape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Concat)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool,
|
||||
Equal)>,
|
||||
|
|
@ -1650,7 +1655,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Exp)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Log)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Log)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 14, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Split)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Unsqueeze)>,
|
||||
|
|
@ -1810,13 +1815,22 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
Div)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, Reshape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, Identity)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, float,
|
||||
BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, double,
|
||||
BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, float,
|
||||
BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, double,
|
||||
BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, GRU)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, LSTM)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, RNN)>,
|
||||
|
||||
// Opset 15
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 15, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 15, float,
|
||||
BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 15, double,
|
||||
BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 15, Shape)>,
|
||||
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
|
|
@ -259,13 +259,27 @@ REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sqrt, 6, 12, double, Sqrt);
|
|||
REG_ELEMENTWISE_TYPED_KERNEL(Sqrt, 13, float, Sqrt);
|
||||
REG_ELEMENTWISE_TYPED_KERNEL(Sqrt, 13, double, Sqrt);
|
||||
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Pow, 7, 11, Pow, BuildKernelDefConstraintsFromTypeList<Pow7Types>(), BuildKernelDefConstraintsFromTypeList<EnabledPow7Types>());
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Pow, 7, 11, Pow,
|
||||
BuildKernelDefConstraintsFromTypeList<Pow7Types>(),
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledPow7Types>());
|
||||
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT_2(Pow, 12, 12, Pow,
|
||||
BuildKernelDefConstraintsFromTypeList<Pow12BaseTypes>(), BuildKernelDefConstraintsFromTypeList<EnabledPow12BaseTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<Pow12ExpTypes>(), BuildKernelDefConstraintsFromTypeList<EnabledPow12ExpTypes>());
|
||||
REG_ELEMENTWISE_KERNEL_NONT_2(Pow, 13, Pow,
|
||||
BuildKernelDefConstraintsFromTypeList<Pow12BaseTypes>(), BuildKernelDefConstraintsFromTypeList<EnabledPow12BaseTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<Pow12ExpTypes>(), BuildKernelDefConstraintsFromTypeList<EnabledPow12ExpTypes>());
|
||||
BuildKernelDefConstraintsFromTypeList<Pow12BaseTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledPow12BaseTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<Pow12ExpTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledPow12ExpTypes>());
|
||||
|
||||
// Registers the Pow kernel for the closed opset range 13 through 14.
// The base and exponent type constraints reuse the opset-12 type lists
// (Pow12BaseTypes / Pow12ExpTypes and their "Enabled" counterparts).
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT_2(Pow, 13, 14, Pow,
|
||||
BuildKernelDefConstraintsFromTypeList<Pow12BaseTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledPow12BaseTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<Pow12ExpTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledPow12ExpTypes>());
|
||||
|
||||
// Registers the Pow kernel starting at opset 15 (non-versioned macro, so
// presumably open-ended - confirm against the macro definition).
// Uses the same opset-12 base/exponent type lists as the earlier registrations.
REG_ELEMENTWISE_KERNEL_NONT_2(Pow, 15, Pow,
|
||||
BuildKernelDefConstraintsFromTypeList<Pow12BaseTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledPow12BaseTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<Pow12ExpTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledPow12ExpTypes>());
|
||||
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Exp, 6, 12, float, Exp);
|
||||
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Exp, 6, 12, double, Exp);
|
||||
|
|
|
|||
|
|
@ -31,18 +31,52 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 7, 8, double,
|
|||
|
||||
// We alias the running mean to the mean so it stays preserved across multiple batches
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 9, 13, float,
|
||||
KernelDefBuilder().Alias(3,1).Alias(4,2).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
BatchNorm<float>);
|
||||
KernelDefBuilder().Alias(3, 1).Alias(4, 2).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
BatchNorm<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 9, 13, double,
|
||||
KernelDefBuilder().Alias(3,1).Alias(4,2).TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
BatchNorm<double>);
|
||||
KernelDefBuilder().Alias(3, 1).Alias(4, 2).TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
BatchNorm<double>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(BatchNormalization, 14, float,
|
||||
KernelDefBuilder().Alias(3,1).Alias(4,2).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
// Registers BatchNorm<float> for BatchNormalization at opset 14 only
// (versioned range 14..14). Both the "T" and "U" constraints are bound to
// tensor(float).
// NOTE(review): Alias(3, 1) / Alias(4, 2) presumably pair an input index with
// an output index so the mean/var inputs share storage with the running
// mean/var outputs - confirm against KernelDefBuilder::Alias.
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 14, 14, float,
|
||||
KernelDefBuilder()
|
||||
.Alias(3, 1)
|
||||
.Alias(4, 2)
|
||||
// ORT 1.8 shipped this kernel with only the "T" type constraint.
|
||||
// To stay backward compatible with the kernel hash from that release,
|
||||
// only "T" is used for hash generation ("U" is excluded).
|
||||
.FixedTypeConstraintForHash("T", {DataTypeImpl::GetTensorType<float>()})
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
|
||||
.TypeConstraint("U", DataTypeImpl::GetTensorType<float>()),
|
||||
BatchNorm<float>);
|
||||
|
||||
// Registers BatchNorm<double> for BatchNormalization at opset 14 only
// (versioned range 14..14). Both the "T" and "U" constraints are bound to
// tensor(double).
// NOTE(review): Alias(3, 1) / Alias(4, 2) presumably pair an input index with
// an output index so the mean/var inputs share storage with the running
// mean/var outputs - confirm against KernelDefBuilder::Alias.
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 14, 14, double,
|
||||
KernelDefBuilder()
|
||||
.Alias(3, 1)
|
||||
.Alias(4, 2)
|
||||
// ORT 1.8 shipped this kernel with only the "T" type constraint.
|
||||
// To stay backward compatible with the kernel hash from that release,
|
||||
// only "T" is used for hash generation ("U" is excluded).
|
||||
.FixedTypeConstraintForHash("T", {DataTypeImpl::GetTensorType<double>()})
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<double>())
|
||||
.TypeConstraint("U", DataTypeImpl::GetTensorType<double>()),
|
||||
BatchNorm<double>);
|
||||
|
||||
// Registers BatchNorm<float> for BatchNormalization starting at opset 15.
// This registration declares three type constraints - "T", "T1" and "T2" -
// and binds all of them to tensor(float), so mixed-type inputs are not
// accepted by this kernel.
// No FixedTypeConstraintForHash call is made here, unlike the 14..14 variant.
ONNX_CPU_OPERATOR_TYPED_KERNEL(BatchNormalization, 15, float,
|
||||
KernelDefBuilder()
|
||||
.Alias(3, 1)
|
||||
.Alias(4, 2)
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
|
||||
.TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()),
|
||||
BatchNorm<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(BatchNormalization, 14, double,
|
||||
KernelDefBuilder().Alias(3,1).Alias(4,2).TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
|
||||
// Registers BatchNorm<double> for BatchNormalization starting at opset 15.
// This registration declares three type constraints - "T", "T1" and "T2" -
// and binds all of them to tensor(double), so mixed-type inputs are not
// accepted by this kernel.
// No FixedTypeConstraintForHash call is made here, unlike the 14..14 variant.
ONNX_CPU_OPERATOR_TYPED_KERNEL(BatchNormalization, 15, double,
|
||||
KernelDefBuilder()
|
||||
.Alias(3, 1)
|
||||
.Alias(4, 2)
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<double>())
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<double>())
|
||||
.TypeConstraint("T2", DataTypeImpl::GetTensorType<double>()),
|
||||
BatchNorm<double>);
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -12,9 +12,16 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
|||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()).TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
Shape);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
Shape,
|
||||
13,
|
||||
13, 14,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()).TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
Shape);
|
||||
|
||||
// Registers the Shape kernel starting at opset 15 (non-versioned macro).
// "T" (the input) accepts every tensor type; "T1" (the output) is int64.
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Shape,
|
||||
15,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()).TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
Shape);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -9,25 +9,62 @@
|
|||
#endif
|
||||
|
||||
#include "gsl/gsl"
|
||||
#include <limits>
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Kernel for the Shape operator: writes the dimensions of input 0 into a
// 1-D int64 output tensor. When the "start" and/or "end" attributes are in
// effect, only the dimensions in the half-open range [start, end) are emitted.
class Shape final : public OpKernel {
|
||||
public:
|
||||
// Reads the optional "start" and "end" attributes and decides once, at
// construction time, whether Compute has to slice the dimension list.
Shape(const OpKernelInfo& info) : OpKernel(info) {
|
||||
// A missing "start" falls back to 0.
info.GetAttrOrDefault<int64_t>("start", &start_index_, 0);
|
||||
|
||||
if (start_index_ != 0) {
|
||||
// "start" is provided and is non-default (default is 0)
|
||||
needs_slicing_ = true;
|
||||
}
|
||||
|
||||
// Any successfully read "end" enables slicing, whatever its value.
// Presumably a failed GetAttr leaves end_index_ at its in-class default
// (int64 max) - confirm against OpKernelInfo::GetAttr.
if (info.GetAttr<int64_t>("end", &end_index_).IsOK()) {
|
||||
needs_slicing_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Takes a tensor as input and outputs a 1D int64 tensor
|
||||
// containing the shape of the input tensor.
|
||||
Status Compute(OpKernelContext* context) const override {
|
||||
const auto* input = context->Input<Tensor>(0);
|
||||
// NOTE(review): inputShape / nDims below and the unconditional CopyDims
// after the if/else look like the pre-change lines of this hunk listed
// next to their replacements (input_shape / rank) - confirm against the
// actual header before reading this as one function body.
const TensorShape& inputShape = input->Shape();
|
||||
const TensorShape& input_shape = input->Shape();
|
||||
|
||||
size_t nDims = inputShape.NumDimensions();
|
||||
Tensor* output = context->Output(0, {gsl::narrow_cast<int64_t>(nDims)});
|
||||
// Rank as a signed value so it can be combined with negative start/end.
int64_t rank = gsl::narrow_cast<int64_t>(input_shape.NumDimensions());
|
||||
|
||||
if (!needs_slicing_) { // vanilla use of Shape (no slicing)
|
||||
// Output holds one element per input dimension.
Tensor* output = context->Output(0, {rank});
|
||||
input_shape.CopyDims(output->template MutableData<int64_t>(), static_cast<size_t>(rank));
|
||||
} else { // slicing is needed
|
||||
int64_t true_start = start_index_;
|
||||
int64_t true_end = end_index_;
|
||||
|
||||
// Deal with negative(s) and clamp
|
||||
// A negative index counts back from the end: add rank once ...
true_start = true_start < 0 ? true_start + rank : true_start;
|
||||
// ... then clamp the result into [0, rank].
true_start = true_start < 0 ? 0 : ((true_start > rank) ? rank : true_start);
|
||||
|
||||
true_end = true_end < 0 ? true_end + rank : true_end;
|
||||
true_end = true_end < 0 ? 0 : ((true_end > rank) ? rank : true_end);
|
||||
|
||||
// end <= start produces an empty (length 0) output rather than an error.
auto slice_length = true_end - true_start;
|
||||
Tensor* output = context->Output(0, {slice_length < 0 ? 0 : slice_length});
|
||||
|
||||
// Nothing to copy for an empty slice.
// Presumably this CopyDims overload takes (destination, first dimension, count) - confirm.
if (slice_length > 0) {
|
||||
input_shape.CopyDims(output->template MutableData<int64_t>(), true_start, slice_length);
|
||||
}
|
||||
}
|
||||
|
||||
inputShape.CopyDims(output->template MutableData<int64_t>(), nDims);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
// True when a non-zero "start" or any "end" attribute was read in the constructor.
bool needs_slicing_ = false;
|
||||
// Raw "start" attribute; may be negative (resolved against the rank in Compute).
int64_t start_index_ = 0;
|
||||
// Raw "end" attribute; the int64 max initializer is clamped to the rank in Compute.
int64_t end_index_ = std::numeric_limits<int64_t>::max();
|
||||
};
|
||||
|
||||
} //namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -892,7 +892,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum);
|
||||
|
||||
//OpSet 13
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Pow);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int64_t, Add);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, uint32_t, Add);
|
||||
|
|
@ -1005,7 +1005,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Cast);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Cast);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Reshape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Shape);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Shape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Size);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Transpose);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterElements);
|
||||
|
|
@ -1163,9 +1163,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, LSTM);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, BatchNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, BatchNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, ReduceMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, ReduceMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, ReduceMin);
|
||||
|
|
@ -1182,6 +1182,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Relu);
|
||||
#endif
|
||||
|
||||
//OpSet 15
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, Pow);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, float, BatchNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, double, BatchNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, MLFloat16, BatchNormalization);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, Shape);
|
||||
|
||||
template <>
|
||||
KernelCreateInfo BuildKernelCreateInfo<void>() {
|
||||
return {};
|
||||
|
|
@ -1728,7 +1735,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum)>,
|
||||
|
||||
// OpSet 13
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int64_t, Add)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, uint32_t, Add)>,
|
||||
|
|
@ -1841,7 +1848,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Cast)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Cast)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Reshape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Shape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Shape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Size)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Transpose)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterElements)>,
|
||||
|
|
@ -1998,9 +2005,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, Reshape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, ReduceMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, ReduceMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, ReduceMin)>,
|
||||
|
|
@ -2015,6 +2022,14 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Div)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Relu)>,
|
||||
#endif
|
||||
|
||||
// OpSet 15
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, float, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, double, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, MLFloat16, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, Shape)>,
|
||||
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
|
|
@ -205,9 +205,9 @@ Status BinaryElementwise<ShouldBroadcast>::Prepare(OpKernelContext* context, Bin
|
|||
#define BINARY_OP_TYPED_VERSIONED_V_BF16(name, class_name, startver, endver)
|
||||
#endif
|
||||
|
||||
#define BINARY_OP_VERSIONED_HFD(name, startver, endver) \
|
||||
BINARY_OP_VERSIONED_TYPED(name, startver, endver, MLFloat16) \
|
||||
BINARY_OP_VERSIONED_TYPED(name, startver, endver, float) \
|
||||
#define BINARY_OP_VERSIONED_HFD(name, startver, endver) \
|
||||
BINARY_OP_VERSIONED_TYPED(name, startver, endver, MLFloat16) \
|
||||
BINARY_OP_VERSIONED_TYPED(name, startver, endver, float) \
|
||||
BINARY_OP_VERSIONED_TYPED(name, startver, endver, double)
|
||||
|
||||
#define BINARY_OP_VERSIONED_UZILHFD(name, startver, endver) \
|
||||
|
|
@ -318,15 +318,29 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
kOnnxDomain,
|
||||
12, 12,
|
||||
kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>()).TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>()),
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>())
|
||||
.TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>()),
|
||||
Pow);
|
||||
|
||||
// Registers the CUDA Pow kernel for the closed opset range 13 through 14.
// "T" and "T1" each accept int32, int64, float, double and MLFloat16.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Pow,
|
||||
kOnnxDomain,
|
||||
13, 14,
|
||||
kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>())
|
||||
.TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>()),
|
||||
Pow);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Pow,
|
||||
kOnnxDomain,
|
||||
13,
|
||||
15,
|
||||
kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>()).TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>()),
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>())
|
||||
.TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>()),
|
||||
Pow);
|
||||
|
||||
namespace pow12_internal {
|
||||
|
|
@ -524,6 +538,5 @@ BINARY_OP_REGISTER_VERSIONED_HFD(Less, 7, 8)
|
|||
BINARY_LOGICALOP_REGISTER_UZILHFD(GreaterOrEqual, 12)
|
||||
BINARY_LOGICALOP_REGISTER_UZILHFD(LessOrEqual, 12)
|
||||
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -11,33 +11,45 @@ using namespace std;
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
|
||||
BatchNormalization, \
|
||||
kOnnxDomain, \
|
||||
7, 8, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
BatchNorm<T>); \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
|
||||
BatchNormalization, \
|
||||
kOnnxDomain, \
|
||||
9, 13, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
BatchNorm<T>); \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
BatchNormalization, \
|
||||
kOnnxDomain, \
|
||||
14, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
|
||||
BatchNormalization, \
|
||||
kOnnxDomain, \
|
||||
7, 8, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
BatchNorm<T>); \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
|
||||
BatchNormalization, \
|
||||
kOnnxDomain, \
|
||||
9, 13, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
BatchNorm<T>); \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
|
||||
BatchNormalization, \
|
||||
kOnnxDomain, \
|
||||
14, 14, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("U", DataTypeImpl::GetTensorType<T>()), \
|
||||
BatchNorm<T>); \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
BatchNormalization, \
|
||||
kOnnxDomain, \
|
||||
15, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("T2", DataTypeImpl::GetTensorType<T>()), \
|
||||
BatchNorm<T>);
|
||||
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -39,8 +39,8 @@ class BatchNorm final : public CudaKernel {
|
|||
const auto& node = op_kernel_info.node();
|
||||
auto opset = node.SinceVersion();
|
||||
|
||||
// batch norm opset 14 is not implemented for training mode
|
||||
ORT_ENFORCE(!(is_training_mode_ && opset==14), "Training mode does not support BN opset 14 yet.");
|
||||
// batch norm opset 14 (or higher) is not implemented for training mode
|
||||
ORT_ENFORCE(!(is_training_mode_ && opset >= 14), "Training mode does not support BN opset 14 (or higher) yet.");
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
|
@ -50,7 +50,7 @@ class BatchNorm final : public CudaKernel {
|
|||
int64_t spatial_ = 1; // default as per spec
|
||||
cudnnBatchNormMode_t cudnn_batch_norm_mode_;
|
||||
double momentum_;
|
||||
bool is_training_mode_ = 0; //default as per spec
|
||||
bool is_training_mode_ = 0; //default as per spec
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -20,10 +20,22 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
Shape);
|
||||
|
||||
// Registers the CUDA Shape kernel for the closed opset range 13 through 14.
// "T" (the input) accepts all fixed-size tensor types; "T1" (the output) is
// int64.
// NOTE(review): OutputMemoryType(OrtMemTypeCPUInput, 0) presumably places
// output 0 in CPU-accessible memory - confirm the intended memory type.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Shape,
|
||||
kOnnxDomain,
|
||||
13, 14,
|
||||
kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
// properly force CPU/GPU synch inside the kernel
|
||||
.OutputMemoryType(OrtMemTypeCPUInput, 0)
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
Shape);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Shape,
|
||||
kOnnxDomain,
|
||||
13,
|
||||
15,
|
||||
kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
// properly force CPU/GPU synch inside the kernel
|
||||
|
|
|
|||
|
|
@ -827,7 +827,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, Einsum);
|
||||
|
||||
//OpSet 13
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Pow);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Pow);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int64_t, Add);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, uint32_t, Add);
|
||||
|
|
@ -934,7 +934,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Cast);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, bool, Cast);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Reshape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Shape);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Transpose);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterElements);
|
||||
|
|
@ -1029,6 +1029,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Pad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad);
|
||||
|
||||
// opset 15
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 15, Shape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 15, Pow);
|
||||
|
||||
template <>
|
||||
KernelCreateInfo BuildKernelCreateInfo<void>() {
|
||||
KernelCreateInfo info;
|
||||
|
|
@ -1555,7 +1559,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, Einsum)>,
|
||||
|
||||
// OpSet 13
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int64_t, Add)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, uint32_t, Add)>,
|
||||
|
|
@ -1662,7 +1666,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Cast)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, bool, Cast)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Reshape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Shape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Transpose)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterElements)>,
|
||||
|
|
@ -1756,6 +1760,10 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad)>,
|
||||
|
||||
// opset 15
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 15, Shape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 15, Pow)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
|
|
@ -291,7 +291,8 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
double per_sample_tolerance = 1e-3;
|
||||
// when cuda is enabled, set it to a larger value for resolving random MNIST test failure
|
||||
// when openvino is enabled, set it to a larger value for resolving MNIST accuracy mismatch
|
||||
double relative_per_sample_tolerance = enable_cuda ? 0.017 : enable_openvino ? 0.009 : 1e-3;
|
||||
double relative_per_sample_tolerance = enable_cuda ? 0.017 : enable_openvino ? 0.009
|
||||
: 1e-3;
|
||||
|
||||
Ort::SessionOptions sf;
|
||||
|
||||
|
|
@ -480,8 +481,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
ORT_TSTR("operator_pow"),
|
||||
ORT_TSTR("bernoulli"),
|
||||
ORT_TSTR("bernoulli_double"),
|
||||
ORT_TSTR("bernoulli_seed")
|
||||
};
|
||||
ORT_TSTR("bernoulli_seed")};
|
||||
|
||||
static const ORTCHAR_T* cuda_flaky_tests[] = {
|
||||
ORT_TSTR("fp16_inception_v1"),
|
||||
|
|
@ -600,16 +600,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
{"bernoulli_seed", "By design. Test data is for informational purpose because the generator is non deterministic."},
|
||||
{"bernoulli_seed_expanded", "By design. Test data is for informational purpose because the generator is non deterministic."},
|
||||
{"bernoulli_expanded", "By design. Test data is for informational purpose because the generator is non deterministic."},
|
||||
{"shape", "opset15 updates not supported yet."},
|
||||
{"shape_clip_end", "opset15 updates not supported yet."},
|
||||
{"shape_clip_start", "opset15 updates not supported yet."},
|
||||
{"shape_end_1", "opset15 updates not supported yet."},
|
||||
{"shape_end_negative_1", "opset15 updates not supported yet."},
|
||||
{"shape_example", "opset15 updates not supported yet."},
|
||||
{"shape_start_1", "opset15 updates not supported yet."},
|
||||
{"shape_start_1_end_2", "opset15 updates not supported yet."},
|
||||
{"shape_start_1_end_negative_1", "opset15 updates not supported yet."},
|
||||
{"shape_start_negative_1", "opset15 updates not supported yet."},
|
||||
{"test_optional_get_element", "opset15 updates not supported yet."},
|
||||
{"test_optional_get_element_sequence", "opset15 updates not supported yet."},
|
||||
{"test_optional_has_element", "opset15 updates not supported yet."},
|
||||
|
|
|
|||
|
|
@ -53,7 +53,6 @@
|
|||
#include "core/optimizer/relu_clip_fusion.h"
|
||||
#include "core/optimizer/reshape_fusion.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
#include "core/optimizer/shape_to_initializer.h"
|
||||
#include "core/optimizer/skip_layer_norm_fusion.h"
|
||||
#include "core/optimizer/slice_elimination.h"
|
||||
#include "core/optimizer/unsqueeze_elimination.h"
|
||||
|
|
@ -480,7 +479,7 @@ TEST_F(GraphTransformationTests, ConstantFolding_RemoveDanglingInputNodesToConst
|
|||
ASSERT_TRUE(op_to_count["RandomUniform"] == 0);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, ShapeToInitializer) {
|
||||
TEST_F(GraphTransformationTests, ConstantFoldingAShapeNodeDeepInTheGraph) {
|
||||
auto model_uri = MODEL_FOLDER "shape-add.onnx";
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
|
||||
|
|
@ -489,17 +488,21 @@ TEST_F(GraphTransformationTests, ShapeToInitializer) {
|
|||
ASSERT_TRUE(op_to_count["Shape"] == 4);
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
|
||||
rule_transformer_L1->Register(std::make_unique<ShapeToInitializer>());
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
|
||||
|
||||
std::unique_ptr<CPUExecutionProvider> e =
|
||||
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
graph_transformation_mgr.Register(std::make_unique<ConstantFolding>(*e.get(),
|
||||
false /*skip_dequantize_linear*/),
|
||||
TransformerLevel::Level1);
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
// Two of the Shapes are not eliminated because:
|
||||
// One includes a symbolic dimension.
|
||||
// Another one includes a negative dimension
|
||||
ASSERT_TRUE(op_to_count["Shape"] == 2);
|
||||
|
||||
// A Shape node very deep in the graph (feeding into an Identity
|
||||
// node that produces the graph output) gets constant folded which
|
||||
// removes all its ancestors and the Identity node consuming this Shape's
|
||||
// output is subsequently constant folded to leave the graph with no
|
||||
// nodes.
|
||||
ASSERT_TRUE(op_to_count.size() == 0);
|
||||
}
|
||||
|
||||
// Check transformations in the case of a subgraph with constant inputs.
|
||||
|
|
@ -674,8 +677,8 @@ TEST_F(GraphTransformationTests, FuseCudaConvAddRelu) {
|
|||
graph_transformation_mgr.Register(std::make_unique<ConvActivationFusion>(), TransformerLevel::Level2);
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 0); //Add removed from graph
|
||||
ASSERT_TRUE(op_to_count["Relu"] == 0); //Relu removed from graph
|
||||
ASSERT_TRUE(op_to_count["Add"] == 0); //Add removed from graph
|
||||
ASSERT_TRUE(op_to_count["Relu"] == 0); //Relu removed from graph
|
||||
}
|
||||
|
||||
// Conv->Add->Relu will be left intact since there is an Identity node that depends on Add
|
||||
|
|
@ -695,9 +698,9 @@ TEST_F(GraphTransformationTests, FuseCudaConvAddReluIdentity) {
|
|||
graph_transformation_mgr.Register(std::make_unique<ConvActivationFusion>(), TransformerLevel::Level2);
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 1); //Add remains
|
||||
ASSERT_TRUE(op_to_count["Relu"] == 1); //Relu remains
|
||||
ASSERT_TRUE(op_to_count["Identity"] == 1); //Identity remains
|
||||
ASSERT_TRUE(op_to_count["Add"] == 1); //Add remains
|
||||
ASSERT_TRUE(op_to_count["Relu"] == 1); //Relu remains
|
||||
ASSERT_TRUE(op_to_count["Identity"] == 1); //Identity remains
|
||||
}
|
||||
|
||||
// Conv->Add will be left intact since no Relu follows it
|
||||
|
|
@ -715,7 +718,7 @@ TEST_F(GraphTransformationTests, FuseCudaConvAdd) {
|
|||
graph_transformation_mgr.Register(std::make_unique<ConvActivationFusion>(), TransformerLevel::Level2);
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 1); //Add remains, no transform applied to the graph
|
||||
ASSERT_TRUE(op_to_count["Add"] == 1); //Add remains, no transform applied to the graph
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
@ -4131,13 +4134,13 @@ TEST_F(GraphTransformationTests, FilterEnabledOptimizers) {
|
|||
|
||||
const auto& graph = session_object.GetGraph();
|
||||
|
||||
// check the ops that should go away if the constant folding transformer or ShapeToInitializer rewrite rule run
|
||||
// check the ops that should go away if the constant folding transformer runs
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Shape"] == 1);
|
||||
ASSERT_TRUE(op_to_count["ConstantOfShape"] == 1);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 1);
|
||||
|
||||
ASSERT_STATUS_OK(session_object.FilterEnabledOptimizers({"ConstantFolding", "ShapeToInitializer"}));
|
||||
ASSERT_STATUS_OK(session_object.FilterEnabledOptimizers({"ConstantFolding"}));
|
||||
ASSERT_STATUS_OK(session_object.Initialize()); // Initialize runs the transformers
|
||||
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
|
|
|
|||
|
|
@ -688,6 +688,21 @@ TEST(MathOpTest, Pow_Float_12) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
// Element-wise Pow on a 2x2 float tensor, built against opset 15.
TEST(MathOpTest, Pow_Float_15) {
  OpTester test("Pow", 15);
  const std::vector<int64_t> shape{2, 2};

  // Bases (X) and exponents (Y), paired element by element.
  test.AddInput<float>("X", shape, {2.0f, 2.0f, std::sqrt(2.0f), 1.0f});
  test.AddInput<float>("Y", shape, {0.0f, 8.0f, 2.0f, 9.0f});

  // Expected Z[i] = X[i] ^ Y[i]: 2^0, 2^8, sqrt(2)^2, 1^9.
  test.AddOutput<float>("Z", shape, {1.0f, 256.0f, 2.0f, 1.0f});

  test.Run();
}
|
||||
|
||||
TEST(MathOpTest, Pow_Double_12) {
|
||||
OpTester test("Pow", 12);
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
|
|
@ -1635,7 +1650,7 @@ TEST(MathOpTest, LessOrEqual) {
|
|||
test.AddInput<float>("B", dims, {1.0f, 1.0f, 2.0f, -1.0f});
|
||||
test.AddOutput<bool>("C", dims, {true, true, true, true});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, LessOrEqual_Scalar0) {
|
||||
|
|
@ -1644,7 +1659,7 @@ TEST(MathOpTest, LessOrEqual_Scalar0) {
|
|||
test.AddInput<float>("B", {4}, {1.0f, 1.5f, 2.0f, -1.0f});
|
||||
test.AddOutput<bool>("C", {4}, {true, true, true, false});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, LessOrEqual_Scalar1) {
|
||||
|
|
@ -1653,7 +1668,7 @@ TEST(MathOpTest, LessOrEqual_Scalar1) {
|
|||
test.AddInput<float>("B", {1}, {1.0f});
|
||||
test.AddOutput<bool>("C", {4}, {true, true, false, true});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, LessOrEqual_int64_Scalar1) {
|
||||
|
|
@ -1662,7 +1677,7 @@ TEST(MathOpTest, LessOrEqual_int64_Scalar1) {
|
|||
test.AddInput<int64_t>("B", {1}, {1});
|
||||
test.AddOutput<bool>("C", {4}, {true, true, false, true});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
TEST(MathOpTest, LessOrEqual_broadcastAB) {
|
||||
OpTester test("LessOrEqual", 12);
|
||||
|
|
@ -1670,7 +1685,7 @@ TEST(MathOpTest, LessOrEqual_broadcastAB) {
|
|||
test.AddInput<int32_t>("B", {2}, {15, 7});
|
||||
test.AddOutput<bool>("C", {4, 2}, {true, false, true, false, true, false, false, false});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, LessOrEqual_broadcastBA) {
|
||||
|
|
@ -1679,7 +1694,7 @@ TEST(MathOpTest, LessOrEqual_broadcastBA) {
|
|||
test.AddInput<int32_t>("B", {4, 2}, {10, 11, 12, 13, 14, 15, 16, 17});
|
||||
test.AddOutput<bool>("C", {4, 2}, {false, true, false, true, false, true, true, true});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, LessOrEqual_multidiretional_broadcastAB) {
|
||||
|
|
@ -1688,7 +1703,7 @@ TEST(MathOpTest, LessOrEqual_multidiretional_broadcastAB) {
|
|||
test.AddInput<int32_t>("B", {2}, {15, 7});
|
||||
test.AddOutput<bool>("C", {4, 2}, {true, false, true, false, true, false, true, false});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, LessOrEqual_multidiretional_broadcastBA) {
|
||||
|
|
@ -1697,7 +1712,7 @@ TEST(MathOpTest, LessOrEqual_multidiretional_broadcastBA) {
|
|||
test.AddInput<int32_t>("B", {4, 1}, {10, 11, 12, 13});
|
||||
test.AddOutput<bool>("C", {4, 2}, {false, true, false, true, false, true, false, true});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Greater_7) {
|
||||
|
|
@ -1784,7 +1799,7 @@ TEST(MathOpTest, GreaterOrEqual_12_float) {
|
|||
test.AddInput<float>("B", dims, {1.0f, 1.0f, 2.0f, -1.0f});
|
||||
test.AddOutput<bool>("C", dims, {true, false, false, true});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, GreaterOrEqual_12_double) {
|
||||
|
|
@ -1794,7 +1809,7 @@ TEST(MathOpTest, GreaterOrEqual_12_double) {
|
|||
test.AddInput<double>("B", dims, {1.0, 1.0, 2.0, -1.0});
|
||||
test.AddOutput<bool>("C", dims, {true, false, true, true});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, GreaterOrEqual_12_int32) {
|
||||
|
|
@ -1804,7 +1819,7 @@ TEST(MathOpTest, GreaterOrEqual_12_int32) {
|
|||
test.AddInput<int32_t>("B", dims, {15, 7, 12, 9});
|
||||
test.AddOutput<bool>("C", dims, {false, true, true, true});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, GreaterOrEqual_12_int64) {
|
||||
|
|
@ -1814,7 +1829,7 @@ TEST(MathOpTest, GreaterOrEqual_12_int64) {
|
|||
test.AddInput<int64_t>("B", dims, {15, 7, 12, 9});
|
||||
test.AddOutput<bool>("C", dims, {false, true, true, true});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, GreaterOrEqual_broadcastAB) {
|
||||
|
|
@ -1823,7 +1838,7 @@ TEST(MathOpTest, GreaterOrEqual_broadcastAB) {
|
|||
test.AddInput<int32_t>("B", {2}, {15, 7});
|
||||
test.AddOutput<bool>("C", {4, 2}, {false, true, false, true, false, true, true, true});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, GreaterOrEqual_broadcastBA) {
|
||||
|
|
@ -1832,7 +1847,7 @@ TEST(MathOpTest, GreaterOrEqual_broadcastBA) {
|
|||
test.AddInput<int32_t>("B", {4, 2}, {10, 11, 12, 13, 14, 15, 16, 17});
|
||||
test.AddOutput<bool>("C", {4, 2}, {true, false, true, false, true, false, false, false});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, GreaterOrEqual_multidiretional_broadcastAB) {
|
||||
|
|
@ -1841,7 +1856,7 @@ TEST(MathOpTest, GreaterOrEqual_multidiretional_broadcastAB) {
|
|||
test.AddInput<int32_t>("B", {2}, {15, 7});
|
||||
test.AddOutput<bool>("C", {4, 2}, {false, true, false, true, false, true, false, true});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, GreaterOrEqual_multidiretional_broadcastBA) {
|
||||
|
|
@ -1850,7 +1865,7 @@ TEST(MathOpTest, GreaterOrEqual_multidiretional_broadcastBA) {
|
|||
test.AddInput<int32_t>("B", {4, 1}, {10, 11, 12, 13});
|
||||
test.AddOutput<bool>("C", {4, 2}, {true, false, true, false, true, false, true, false});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Equal_bool) {
|
||||
|
|
|
|||
|
|
@ -790,5 +790,30 @@ TEST(BatchNormTest, ForwardTrainingTestOpset14) {
|
|||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
|
||||
}
|
||||
|
||||
// BatchNormalization at opset 15 with training_mode set to 1: checks the
// normalized output Y together with the running_mean / running_var outputs.
TEST(BatchNormTest, ForwardTrainingTestOpset15) {
  OpTester test("BatchNormalization", 15);

  // Operator attributes.
  const float eps = 1e-05f;
  const float mom = 0.1f;
  const int64_t is_training = 1;
  test.AddAttribute("epsilon", eps);
  test.AddAttribute("momentum", mom);
  test.AddAttribute("training_mode", is_training);

  // X and Y share one shape; scale, B, mean and var share the other.
  const std::vector<int64_t> data_dims{2, 2, 2, 2};
  const std::vector<int64_t> per_channel_dims{2};

  test.AddInput<float>("X", data_dims,
                       {-0.2953f, 0.1180f, 1.0973f, -0.1931f, -0.1999f, -0.0237f, 1.5181f, 0.0076f,
                        -1.0830f, -1.5433f, 0.4327f, -0.9813f, 0.7875f, -0.4080f, -2.3144f, 1.5493f});
  test.AddInput<float>("scale", per_channel_dims, {1.0f, 1.0f});
  test.AddInput<float>("B", per_channel_dims, {0.0f, 0.0f});
  test.AddInput<float>("mean", per_channel_dims, {1.0f, 2.0f});
  test.AddInput<float>("var", per_channel_dims, {1.0f, 2.0f});

  test.AddOutput<float>("Y", data_dims,
                        {0.0131f, 0.5210f, 1.7244f, 0.1387f, -0.2708f, -0.1191f, 1.2089f, -0.0922f,
                         -0.9548f, -1.5203f, 0.9077f, -0.8298f, 0.5796f, -0.4501f, -2.0921f, 1.2358f});
  test.AddOutput<float>("running_mean", per_channel_dims, {-0.1754f, 0.303106f});
  test.AddOutput<float>("running_var", per_channel_dims, {0.696052f, 1.41316f});

  // Same exclusions as the opset 14 test
  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
           {kCudaExecutionProvider, kTensorrtExecutionProvider,
            kOpenVINOExecutionProvider, kDnnlExecutionProvider});
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -4,26 +4,75 @@
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
template<typename T>
|
||||
void TestShape(const std::initializer_list<T>& data, const std::vector<int64_t>& shape)
|
||||
{
|
||||
template <typename T>
|
||||
void TestShape(const std::initializer_list<T>& data, const std::vector<int64_t>& shape) {
|
||||
OpTester test("Shape");
|
||||
test.AddInput<T>("data", shape, data);
|
||||
test.AddOutput<int64_t>("output", {static_cast<int64_t>(shape.size())}, shape);
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});//TensorRT parser: unsupported data types
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT parser: unsupported data types
|
||||
}
|
||||
|
||||
// One test per input element type. Each passes sample data and the dimensions
// that TestShape expects the Shape operator to report back.
TEST(ShapeOpTest, ShapeTestBool) { TestShape<bool>({true, true, false, false, true, false}, {2, 3}); }
TEST(ShapeOpTest, ShapeTestFloat) { TestShape<float>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {2, 6}); }
TEST(ShapeOpTest, ShapeTestDouble) { TestShape<double>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {6, 2}); }
TEST(ShapeOpTest, ShapeTestInt8) { TestShape<int8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {3, 4}); }
TEST(ShapeOpTest, ShapeTestInt16) { TestShape<int16_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {3, 4}); }
TEST(ShapeOpTest, ShapeTestInt32) { TestShape<int32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {4, 3}); }
TEST(ShapeOpTest, ShapeTestInt64) { TestShape<int64_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); }
TEST(ShapeOpTest, ShapeTestUint8) { TestShape<uint8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {12, 1}); }
TEST(ShapeOpTest, ShapeTestUint16) { TestShape<uint16_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); }
TEST(ShapeOpTest, ShapeTestUint32) { TestShape<uint32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {12, 1}); }
TEST(ShapeOpTest, ShapeTestUint64) { TestShape<uint64_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); }
TEST(ShapeOpTest, ShapeTestString) { TestShape<std::string>({"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"}, {1, 12}); }
|
||||
|
||||
// Opset 15 Shape with neither `start` nor `end` set: all dimensions are returned.
TEST(ShapeOpTest, ShapeOpset15_Default) {
  OpTester test("Shape", 15);
  const std::vector<int64_t> input_dims{1, 2, 2};

  test.AddInput<int32_t>("data", input_dims, {1, 2, 3, 4});
  test.AddOutput<int64_t>("output", {3}, {1, 2, 2});

  //TensorRT parser: unsupported data types
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
|
||||
|
||||
// Opset 15 Shape with only `start` = 1: the leading dimension is dropped.
TEST(ShapeOpTest, ShapeOpset15_StartOnly) {
  OpTester test("Shape", 15);
  test.AddAttribute<int64_t>("start", 1);

  const std::vector<int64_t> input_dims{1, 2, 2};
  test.AddInput<int32_t>("data", input_dims, {1, 2, 3, 4});
  test.AddOutput<int64_t>("output", {2}, {2, 2});

  //TensorRT parser: unsupported data types
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
|
||||
|
||||
// Opset 15 Shape with only `end` = 2: the first two dimensions are returned.
TEST(ShapeOpTest, ShapeOpset15_EndOnly) {
  OpTester test("Shape", 15);
  test.AddAttribute<int64_t>("end", 2);

  const std::vector<int64_t> input_dims{1, 2, 2};
  test.AddInput<int32_t>("data", input_dims, {1, 2, 3, 4});
  test.AddOutput<int64_t>("output", {2}, {1, 2});

  //TensorRT parser: unsupported data types
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
|
||||
|
||||
// Opset 15 Shape with `start` = 1 and `end` = 2: only the middle dimension is returned.
TEST(ShapeOpTest, ShapeOpset15_StartAndEnd) {
  OpTester test("Shape", 15);
  test.AddAttribute<int64_t>("start", 1);
  test.AddAttribute<int64_t>("end", 2);

  const std::vector<int64_t> input_dims{1, 2, 2};
  test.AddInput<int32_t>("data", input_dims, {1, 2, 3, 4});
  test.AddOutput<int64_t>("output", {1}, {2});

  //TensorRT parser: unsupported data types
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
|
||||
|
||||
// Opset 15 Shape with negative `start` (-2) and `end` (-1) on a rank-3 input;
// the expected output is the single middle dimension.
TEST(ShapeOpTest, ShapeOpset15_StartAndEndNegative) {
  OpTester test("Shape", 15);
  test.AddAttribute<int64_t>("start", -2);
  test.AddAttribute<int64_t>("end", -1);

  const std::vector<int64_t> input_dims{1, 2, 2};
  test.AddInput<int32_t>("data", input_dims, {1, 2, 3, 4});
  test.AddOutput<int64_t>("output", {1}, {2});

  //TensorRT parser: unsupported data types
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
|
||||
// Opset 15 Shape with `start` equal to `end` (both 2): the expected output is
// an empty 1-D tensor.
TEST(ShapeOpTest, ShapeOpset15_StartAndEndProducingEmptySlice) {
  OpTester test("Shape", 15);
  test.AddAttribute<int64_t>("start", 2);
  test.AddAttribute<int64_t>("end", 2);

  const std::vector<int64_t> input_dims{1, 2, 2};
  test.AddInput<int32_t>("data", input_dims, {1, 2, 3, 4});
  test.AddOutput<int64_t>("output", {0}, {});

  //TensorRT parser: unsupported data types
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -255,6 +255,14 @@
|
|||
"BatchNormalization ai.onnx CPUExecutionProvider",
|
||||
17832136363477464736
|
||||
],
|
||||
[
|
||||
"BatchNormalization ai.onnx CPUExecutionProvider",
|
||||
3016597991190826984
|
||||
],
|
||||
[
|
||||
"BatchNormalization ai.onnx CPUExecutionProvider",
|
||||
9270095107043637928
|
||||
],
|
||||
[
|
||||
"BitShift ai.onnx CPUExecutionProvider",
|
||||
4758677670685660688
|
||||
|
|
@ -1483,6 +1491,10 @@
|
|||
"Pow ai.onnx CPUExecutionProvider",
|
||||
12963226513247425672
|
||||
],
|
||||
[
|
||||
"Pow ai.onnx CPUExecutionProvider",
|
||||
16138602580714332296
|
||||
],
|
||||
[
|
||||
"PRelu ai.onnx CPUExecutionProvider",
|
||||
3282999003886175808
|
||||
|
|
@ -2159,6 +2171,10 @@
|
|||
"Shape ai.onnx CPUExecutionProvider",
|
||||
14989007508280400584
|
||||
],
|
||||
[
|
||||
"Shape ai.onnx CPUExecutionProvider",
|
||||
9917761852037658112
|
||||
],
|
||||
[
|
||||
"Shrink ai.onnx CPUExecutionProvider",
|
||||
4706529740707835200
|
||||
|
|
|
|||
|
|
@ -67,7 +67,6 @@
|
|||
"^test_add_uint8_cpu",
|
||||
"^test_div_uint8_cpu",
|
||||
// Following tests are for opset 15 ops and are not yet implemented in ORT
|
||||
"^test_shape_*",
|
||||
"^test_optional_*",
|
||||
//GPU failures
|
||||
"^test_batchnorm_epsilon_training_mode_cuda",
|
||||
|
|
|
|||
|
|
@ -38,7 +38,6 @@
|
|||
#include "core/optimizer/relu_clip_fusion.h"
|
||||
#include "core/optimizer/reshape_fusion.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
#include "core/optimizer/shape_to_initializer.h"
|
||||
#include "core/optimizer/skip_layer_norm_fusion.h"
|
||||
#include "core/optimizer/slice_elimination.h"
|
||||
#include "core/optimizer/unsqueeze_elimination.h"
|
||||
|
|
@ -75,7 +74,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
case TransformerLevel::Level1: {
|
||||
rule_transformer =
|
||||
std::make_unique<RuleBasedGraphTransformer>(optimizer_utils::GenerateRuleBasedTransformerName(level),
|
||||
compatible_eps);
|
||||
compatible_eps);
|
||||
rule_transformer->Register(std::make_unique<InsertMaxPoolOutput>());
|
||||
rule_transformer->Register(std::make_unique<BatchNormReplacement>());
|
||||
rule_transformer->Register(std::make_unique<UnsqueezeElimination>());
|
||||
|
|
@ -127,16 +126,16 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
if (config.propagate_cast_ops_config.level >= 0) {
|
||||
std::unordered_set<std::string> cuda_execution_provider = {onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider};
|
||||
transformers.emplace_back(std::make_unique<PropagateCastOps>(config.propagate_cast_ops_config.strategy,
|
||||
static_cast<size_t>(config.propagate_cast_ops_config.level),
|
||||
config.propagate_cast_ops_config.allow,
|
||||
cuda_execution_provider));
|
||||
static_cast<size_t>(config.propagate_cast_ops_config.level),
|
||||
config.propagate_cast_ops_config.allow,
|
||||
cuda_execution_provider));
|
||||
}
|
||||
} break;
|
||||
|
||||
case TransformerLevel::Level2: {
|
||||
rule_transformer =
|
||||
std::make_unique<RuleBasedGraphTransformer>(optimizer_utils::GenerateRuleBasedTransformerName(level),
|
||||
compatible_eps);
|
||||
compatible_eps);
|
||||
rule_transformer->Register(std::make_unique<ConcatReplacement>());
|
||||
} break;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue