Add opset 15 kernels for Pow, BatchNorm, and Shape (#8442)

This commit is contained in:
Hariharan Seshadri 2021-08-25 12:04:20 -07:00 committed by GitHub
parent 33a97e995b
commit cee79526fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 471 additions and 276 deletions

View file

@ -38,7 +38,8 @@ Do not modify directly.*
|AveragePool|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(float)|
|||10|**T** = tensor(float)|
|||[7, 9]|**T** = tensor(float)|
|BatchNormalization|*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* input_mean:**U**<br> *in* input_var:**U**<br> *out* Y:**T**<br> *out* running_mean:**U**<br> *out* running_var:**U**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* mean:**T**<br> *in* var:**T**<br> *out* Y:**T**<br> *out* mean:**T**<br> *out* var:**T**<br> *out* saved_mean:**T**<br> *out* saved_var:**T**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T1**<br> *in* B:**T1**<br> *in* input_mean:**T2**<br> *in* input_var:**T2**<br> *out* Y:**T**<br> *out* running_mean:**T2**<br> *out* running_var:**T2**|14+|**T** = tensor(double), tensor(float)|
|BatchNormalization|*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* input_mean:**U**<br> *in* input_var:**U**<br> *out* Y:**T**<br> *out* running_mean:**U**<br> *out* running_var:**U**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* mean:**T**<br> *in* var:**T**<br> *out* Y:**T**<br> *out* mean:**T**<br> *out* var:**T**<br> *out* saved_mean:**T**<br> *out* saved_var:**T**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T1**<br> *in* B:**T1**<br> *in* input_mean:**T2**<br> *in* input_var:**T2**<br> *out* Y:**T**<br> *out* running_mean:**T2**<br> *out* running_var:**T2**|15+|**T** = tensor(double), tensor(float)<br/> **T1** = tensor(double), tensor(float)<br/> **T2** = tensor(double), tensor(float)|
|||14|**T** = tensor(double), tensor(float)<br/> **U** = tensor(double), tensor(float)|
|||[9, 13]|**T** = tensor(double), tensor(float)|
|||[7, 8]|**T** = tensor(double), tensor(float)|
|BitShift|*in* X:**T**<br> *in* Y:**T**<br> *out* Z:**T**|11+|**T** = tensor(uint32), tensor(uint64), tensor(uint8)|
@ -202,7 +203,8 @@ Do not modify directly.*
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[2, 10]|**T** = tensor(double), tensor(float)|
|ParametricSoftplus|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Pow|*in* X:**T**<br> *in* Y:**T**<br> *out* Z:**T**<br><br>or<br><br>*in* X:**T**<br> *in* Y:**T1**<br> *out* Z:**T**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|Pow|*in* X:**T**<br> *in* Y:**T**<br> *out* Z:**T**<br><br>or<br><br>*in* X:**T**<br> *in* Y:**T1**<br> *out* Z:**T**|15+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||[13, 14]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||[7, 11]|**T** = tensor(double), tensor(float)|
|QLinearConv|*in* x:**T1**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T1**<br> *in* w:**T2**<br> *in* w_scale:**tensor(float)**<br> *in* w_zero_point:**T2**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T3**<br> *in* B:**T4**<br> *out* y:**T3**|10+|**T1** = tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(uint8)<br/> **T4** = tensor(int32)|
@ -280,7 +282,8 @@ Do not modify directly.*
|SequenceErase|*in* input_sequence:**S**<br> *in* position:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|SequenceInsert|*in* input_sequence:**S**<br> *in* tensor:**T**<br> *in* position:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|SequenceLength|*in* input_sequence:**S**<br> *out* length:**I**|11+|**I** = tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|Shape|*in* data:**T**<br> *out* shape:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|Shape|*in* data:**T**<br> *out* shape:**T1**|15+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|||[13, 14]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|Shrink|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Sigmoid|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
@ -446,7 +449,8 @@ Do not modify directly.*
|AveragePool|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)|
|||10|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16)|
|||[7, 9]|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16)|
|BatchNormalization|*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* input_mean:**U**<br> *in* input_var:**U**<br> *out* Y:**T**<br> *out* running_mean:**U**<br> *out* running_var:**U**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* mean:**T**<br> *in* var:**T**<br> *out* Y:**T**<br> *out* mean:**T**<br> *out* var:**T**<br> *out* saved_mean:**T**<br> *out* saved_var:**T**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T1**<br> *in* B:**T1**<br> *in* input_mean:**T2**<br> *in* input_var:**T2**<br> *out* Y:**T**<br> *out* running_mean:**T2**<br> *out* running_var:**T2**|14+|**T** = tensor(double), tensor(float), tensor(float16)|
|BatchNormalization|*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* input_mean:**U**<br> *in* input_var:**U**<br> *out* Y:**T**<br> *out* running_mean:**U**<br> *out* running_var:**U**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* mean:**T**<br> *in* var:**T**<br> *out* Y:**T**<br> *out* mean:**T**<br> *out* var:**T**<br> *out* saved_mean:**T**<br> *out* saved_var:**T**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T1**<br> *in* B:**T1**<br> *in* input_mean:**T2**<br> *in* input_var:**T2**<br> *out* Y:**T**<br> *out* running_mean:**T2**<br> *out* running_var:**T2**|15+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(double), tensor(float), tensor(float16)|
|||14|**T** = tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(double), tensor(float), tensor(float16)|
|||[9, 13]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
|Cast|*in* input:**T1**<br> *out* output:**T2**|13+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@ -582,7 +586,8 @@ Do not modify directly.*
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[2, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
|ParametricSoftplus|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Pow|*in* X:**T**<br> *in* Y:**T**<br> *out* Z:**T**<br><br>or<br><br>*in* X:**T**<br> *in* Y:**T1**<br> *out* Z:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|Pow|*in* X:**T**<br> *in* Y:**T**<br> *out* Z:**T**<br><br>or<br><br>*in* X:**T**<br> *in* Y:**T1**<br> *out* Z:**T**|15+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|||[13, 14]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|||[7, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|10+|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
@ -653,7 +658,8 @@ Do not modify directly.*
|SequenceErase|*in* input_sequence:**S**<br> *in* position:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|SequenceInsert|*in* input_sequence:**S**<br> *in* tensor:**T**<br> *in* position:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|SequenceLength|*in* input_sequence:**S**<br> *out* length:**I**|11+|**I** = tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|Shape|*in* data:**T**<br> *out* shape:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|Shape|*in* data:**T**<br> *out* shape:**T1**|15+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|||[13, 14]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|Shrink|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Sigmoid|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|

View file

@ -72,6 +72,15 @@ class TensorShape : private std::vector<int64_t> {
memcpy(dims, data(), sizeof(value_type) * std::min(num_dims, NumDimensions()));
}
/**
   Copy dims, beginning at dimension `start_dim`, into a caller-provided array.

   @param dims       Destination buffer; must have room for at least
                     min(num_dims, NumDimensions() - start_dim) values.
   @param start_dim  Index of the first dimension to copy. Expected to be in the
                     inclusive range [0, NumDimensions() - 1].
   @param num_dims   Capacity of `dims`, i.e. the maximum number of values to copy.

   If `start_dim` is at or past the end of the shape nothing is copied and `dims`
   is left untouched.
*/
void CopyDims(int64_t* dims, size_t start_dim, size_t num_dims) const {
  const size_t rank = NumDimensions();
  // `rank - start_dim` is an unsigned subtraction: without this guard an
  // out-of-range `start_dim` would wrap around to a huge value, std::min would
  // then select `num_dims`, and memcpy would read past the end of the storage.
  if (start_dim >= rank) {
    return;
  }
  memcpy(dims, data() + start_dim, sizeof(value_type) * std::min(num_dims, rank - start_dim));
}
/**
Return underlying vector representation.
*/

View file

@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <limits>
#include "core/optimizer/constant_folding.h"
#include "core/optimizer/utils.h"
#include "core/graph/graph_utils.h"
@ -25,6 +27,20 @@ ConstantFolding::ConstantFolding(const IExecutionProvider& execution_provider,
// We need to handle a Shape node separately as the input doesn't need to be a constant initializer for
// Shape to be able to be constant folded.
static bool ConstantFoldShapeNode(Graph& graph, Node& node) {
// Opset-15 Shape supports slicing using a 'start' and 'end' attribute
const auto& shape_attributes = node.GetAttributes();
int64_t start = 0;
int64_t end = std::numeric_limits<int64_t>::max();
for (const auto& attr : shape_attributes) {
if (attr.first == "start") {
start = attr.second.i();
} else if (attr.first == "end") {
end = attr.second.i();
}
}
auto shape = node.MutableInputDefs()[0]->Shape();
bool is_concrete_shape = true;
std::vector<int64_t> dim_values;
@ -42,14 +58,30 @@ static bool ConstantFoldShapeNode(Graph& graph, Node& node) {
}
if (is_concrete_shape) {
int64_t rank = static_cast<int64_t>(dim_values.size());
// We ascertain the "true" starts/ends (if they were provided)
// Opset-15 Shape op supports slicing shape values
// Deal with negatives and clamp
start = start < 0 ? start + rank : start;
start = start < 0 ? 0 : ((start > rank) ? rank : start);
end = end < 0 ? end + rank : end;
end = end < 0 ? 0 : ((end > rank) ? rank : end);
int64_t slice_length = end - start;
size_t clamped_slice_length = slice_length < 0 ? 0 : static_cast<size_t>(slice_length);
ONNX_NAMESPACE::TensorProto shape_constant;
auto* constant_arg_out = node.MutableOutputDefs()[0];
shape_constant.set_name(constant_arg_out->Name());
shape_constant.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
shape_constant.add_dims(dim_values.size());
shape_constant.set_raw_data(dim_values.data(), dim_values.size() * sizeof(int64_t));
shape_constant.add_dims(clamped_slice_length);
shape_constant.set_raw_data(dim_values.data() + start,
clamped_slice_length * sizeof(int64_t));
ONNX_NAMESPACE::TensorShapeProto result_shape;
result_shape.add_dim()->set_dim_value(dim_values.size());
result_shape.add_dim()->set_dim_value(clamped_slice_length);
constant_arg_out->SetShape(result_shape);
graph.AddInitializedTensor(shape_constant);
}

View file

@ -37,7 +37,6 @@
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/shape_to_initializer.h"
#include "core/optimizer/skip_layer_norm_fusion.h"
#include "core/optimizer/slice_elimination.h"
#include "core/optimizer/unsqueeze_elimination.h"
@ -75,7 +74,6 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
rules.push_back(std::make_unique<FuseReluClip>());
rules.push_back(std::make_unique<GemmTransposeFusion>());
rules.push_back(std::make_unique<NotWhereFusion>());
rules.push_back(std::make_unique<ShapeToInitializer>());
rules.push_back(std::make_unique<ConvAddFusion>());
rules.push_back(std::make_unique<ConvMulFusion>());
rules.push_back(std::make_unique<ConvBNFusion>());

View file

@ -1,80 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/shape_to_initializer.h"
#include "core/graph/graph.h"
#include "core/graph/graph_utils.h"
#include "core/graph/op.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/optimizer_execution_frame.h"
#include "core/framework/op_kernel.h"
#include "core/framework/tensorprotoutils.h"
namespace onnxruntime {
// Replaces a Shape node with an initializer that holds the statically inferred
// dimensions of the node's first input.
//
// graph        Graph that receives the new initializer and loses the node.
// node         The Shape node being rewritten.
// rule_effect  Set to kRemovedCurrentNode only when the node was actually replaced;
//              otherwise left as passed in.
// Always returns Status::OK().
Status ShapeToInitializer::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
  // NOTE(review): the shape pointer is dereferenced without a null check here;
  // presumably SatisfyCondition has already verified it - confirm against the caller.
  const ONNX_NAMESPACE::TensorShapeProto* inferred_shape = node.InputDefs()[0]->Shape();
  const int rank = inferred_shape->dim_size();

  // Gather the inferred dimension values; these become the initializer's payload.
  std::vector<int64_t> dim_values;
  for (const auto& dim : inferred_shape->dim()) {
    dim_values.push_back(gsl::narrow_cast<int64_t>(dim.dim_value()));
  }

  // The initializer takes over the name of the Shape node's output.
  const auto* output_def = node.OutputDefs()[0];
  ONNX_NAMESPACE::TensorProto initializer_proto;
  initializer_proto.set_name(output_def->Name());

  // The result is a 1-D tensor with one element per input dimension.
  initializer_proto.add_dims(gsl::narrow_cast<int64_t>(rank));

  // Element type is taken from the output's declared type.
  initializer_proto.set_data_type(output_def->TypeAsProto()->tensor_type().elem_type());

  // Here we expect little-endian format to set raw data of the TensorProto.
  initializer_proto.set_raw_data(dim_values.data(), dim_values.size() * sizeof(int64_t));

  auto& initializer_arg = graph_utils::AddInitializer(graph, initializer_proto);
  const bool replaced = graph_utils::ReplaceNodeWithInitializer(graph, node, initializer_arg);
  if (replaced) {
    rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
  }

  return Status::OK();
}
// Decides whether a node can be rewritten by this rule.
//
// Returns true only when all of the following hold:
//  - the node passes the "Shape" op type / version {1, 13} / domain check;
//  - the first input has a statically inferred shape;
//  - every dimension of that shape carries a concrete, non-negative value
//    (symbolic or negative dimensions disqualify the node);
//  - the node may be replaced by an initializer named after its output.
bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Shape", {1, 13})) {
    return false;
  }

  const auto* input_shape = node.InputDefs()[0]->Shape();
  if (input_shape == nullptr) {
    return false;
  }

  for (const auto& input_dim : input_shape->dim()) {
    const bool has_usable_value = utils::HasDimValue(input_dim) && input_dim.dim_value() >= 0;
    if (!has_usable_value) {
      return false;
    }
  }

  // The replacement initializer reuses the name of the node's output.
  return graph_utils::CanReplaceNodeWithInitializer(graph, node, node.OutputDefs()[0]->Name(), logger);
}
} // namespace onnxruntime

View file

@ -1,32 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/rewrite_rule.h"
namespace onnxruntime {
/**
@Class ShapeToInitializer

Rewrite rule that replaces a Shape node with an initializer when the shape of the
node's input is statically known (through shape inference), so downstream nodes
consume the initializer instead.

The rule is only attempted on nodes whose op type is "Shape".
*/
class ShapeToInitializer : public RewriteRule {
 public:
  ShapeToInitializer() noexcept : RewriteRule("ShapeToInitializer") {}

  /** Op types this rule targets: just "Shape". */
  std::vector<std::string> TargetOpTypes() const noexcept override {
    return std::vector<std::string>{"Shape"};
  }

 private:
  /** Returns true when `node` qualifies for the rewrite. */
  bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;

  /** Performs the rewrite on `node`, reporting what happened through `rule_effect`. */
  Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -527,7 +527,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, ArgMin);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Reshape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Shape);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 14, Shape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Concat);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Less);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Less);
@ -590,7 +590,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Exp);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Log);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Log);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Pow);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 14, Pow);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, DepthToSpace);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Slice);
@ -686,12 +686,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int64_t, Div);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, Reshape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, Identity);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, float, BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, double, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, GRU);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, LSTM);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, RNN);
// Opset 15
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 15, Pow);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 15, float, BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 15, double, BatchNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 15, Shape);
// !!PLEASE READ BELOW!! Following that, add new entries above this comment
@ -1151,9 +1156,9 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int64_t,
MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, float,
BatchNormalization)>,
BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, double,
BatchNormalization)>,
BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, PRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9,
float, Upsample)>,
@ -1547,7 +1552,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
int32_t, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13,
Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 14, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Concat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool,
Equal)>,
@ -1650,7 +1655,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Exp)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Log)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Log)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 14, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Unsqueeze)>,
@ -1810,13 +1815,22 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, Identity)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, float,
BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, double,
BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, float,
BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, double,
BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, GRU)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, LSTM)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, RNN)>,
// Opset 15
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 15, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 15, float,
BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 15, double,
BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 15, Shape)>,
};
for (auto& function_table_entry : function_table) {

View file

@ -259,13 +259,27 @@ REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sqrt, 6, 12, double, Sqrt);
REG_ELEMENTWISE_TYPED_KERNEL(Sqrt, 13, float, Sqrt);
REG_ELEMENTWISE_TYPED_KERNEL(Sqrt, 13, double, Sqrt);
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Pow, 7, 11, Pow, BuildKernelDefConstraintsFromTypeList<Pow7Types>(), BuildKernelDefConstraintsFromTypeList<EnabledPow7Types>());
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Pow, 7, 11, Pow,
BuildKernelDefConstraintsFromTypeList<Pow7Types>(),
BuildKernelDefConstraintsFromTypeList<EnabledPow7Types>());
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT_2(Pow, 12, 12, Pow,
BuildKernelDefConstraintsFromTypeList<Pow12BaseTypes>(), BuildKernelDefConstraintsFromTypeList<EnabledPow12BaseTypes>(),
BuildKernelDefConstraintsFromTypeList<Pow12ExpTypes>(), BuildKernelDefConstraintsFromTypeList<EnabledPow12ExpTypes>());
REG_ELEMENTWISE_KERNEL_NONT_2(Pow, 13, Pow,
BuildKernelDefConstraintsFromTypeList<Pow12BaseTypes>(), BuildKernelDefConstraintsFromTypeList<EnabledPow12BaseTypes>(),
BuildKernelDefConstraintsFromTypeList<Pow12ExpTypes>(), BuildKernelDefConstraintsFromTypeList<EnabledPow12ExpTypes>());
BuildKernelDefConstraintsFromTypeList<Pow12BaseTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledPow12BaseTypes>(),
BuildKernelDefConstraintsFromTypeList<Pow12ExpTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledPow12ExpTypes>());
REG_ELEMENTWISE_VERSIONED_KERNEL_NONT_2(Pow, 13, 14, Pow,
BuildKernelDefConstraintsFromTypeList<Pow12BaseTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledPow12BaseTypes>(),
BuildKernelDefConstraintsFromTypeList<Pow12ExpTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledPow12ExpTypes>());
REG_ELEMENTWISE_KERNEL_NONT_2(Pow, 15, Pow,
BuildKernelDefConstraintsFromTypeList<Pow12BaseTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledPow12BaseTypes>(),
BuildKernelDefConstraintsFromTypeList<Pow12ExpTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledPow12ExpTypes>());
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Exp, 6, 12, float, Exp);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Exp, 6, 12, double, Exp);

View file

@ -31,18 +31,52 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 7, 8, double,
// We alias the running mean to the mean so it stays preserved across multiple batches
// BatchNormalization, opsets 9-13: a single "T" constraint.
// In every registration below, inputs 3 and 4 are aliased to outputs 1 and 2
// (Alias(3, 1), Alias(4, 2)) so the running statistics stay preserved across
// batches.
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 9, 13, float,
                                         KernelDefBuilder().Alias(3, 1).Alias(4, 2).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
                                         BatchNorm<float>);

ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 9, 13, double,
                                         KernelDefBuilder().Alias(3, 1).Alias(4, 2).TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
                                         BatchNorm<double>);

// BatchNormalization, opset 14 only: constraints "T" and "U". The range is
// closed at 14 so it does not overlap the opset-15 registrations further down.
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 14, 14, float,
                                         KernelDefBuilder()
                                             .Alias(3, 1)
                                             .Alias(4, 2)
                                             // ORT 1.8 was shipped with just the "T" type constraint and
                                             // we want to maintain backwards compatibility for
                                             // the hash and hence just use "T" for the hash generation
                                             .FixedTypeConstraintForHash("T", {DataTypeImpl::GetTensorType<float>()})
                                             .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
                                             .TypeConstraint("U", DataTypeImpl::GetTensorType<float>()),
                                         BatchNorm<float>);

ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 14, 14, double,
                                         KernelDefBuilder()
                                             .Alias(3, 1)
                                             .Alias(4, 2)
                                             // ORT 1.8 was shipped with just the "T" type constraint and
                                             // we want to maintain backwards compatibility for
                                             // the hash and hence just use "T" for the hash generation
                                             .FixedTypeConstraintForHash("T", {DataTypeImpl::GetTensorType<double>()})
                                             .TypeConstraint("T", DataTypeImpl::GetTensorType<double>())
                                             .TypeConstraint("U", DataTypeImpl::GetTensorType<double>()),
                                         BatchNorm<double>);

// BatchNormalization, opset 15 onwards: constraints "T", "T1" and "T2", all
// bound to the same element type as the kernel instantiation.
ONNX_CPU_OPERATOR_TYPED_KERNEL(BatchNormalization, 15, float,
                               KernelDefBuilder()
                                   .Alias(3, 1)
                                   .Alias(4, 2)
                                   .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
                                   .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
                                   .TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()),
                               BatchNorm<float>);

ONNX_CPU_OPERATOR_TYPED_KERNEL(BatchNormalization, 15, double,
                               KernelDefBuilder()
                                   .Alias(3, 1)
                                   .Alias(4, 2)
                                   .TypeConstraint("T", DataTypeImpl::GetTensorType<double>())
                                   .TypeConstraint("T1", DataTypeImpl::GetTensorType<double>())
                                   .TypeConstraint("T2", DataTypeImpl::GetTensorType<double>()),
                               BatchNorm<double>);
} // namespace onnxruntime

View file

@ -12,9 +12,16 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()).TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
Shape);
// Shape, opsets 13-14. Registered with a closed version range so that it does
// not overlap the opset-15 registration that follows.
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
    Shape,
    13, 14,
    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()).TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
    Shape);

// Shape, opset 15 onwards. Same type constraints and the same kernel class as
// the 13-14 registration.
ONNX_CPU_OPERATOR_KERNEL(
    Shape,
    15,
    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()).TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
    Shape);
} // namespace onnxruntime

View file

@ -9,25 +9,62 @@
#endif
#include "gsl/gsl"
#include <limits>
namespace onnxruntime {
// CPU kernel for the Shape operator.
//
// Takes a tensor as input and outputs a 1-D int64 tensor containing the shape
// of the input tensor. When the optional "start" / "end" attributes are
// present, only the dimensions in the half-open range [start, end) are emitted.
class Shape final : public OpKernel {
 public:
  // Reads the optional "start" and "end" attributes. Slicing is enabled when
  // "start" is non-zero or when "end" is present at all.
  Shape(const OpKernelInfo& info) : OpKernel(info) {
    info.GetAttrOrDefault<int64_t>("start", &start_index_, 0);

    if (start_index_ != 0) {
      // "start" is provided and is non-default (default is 0)
      needs_slicing_ = true;
    }

    if (info.GetAttr<int64_t>("end", &end_index_).IsOK()) {
      needs_slicing_ = true;
    }
  }

  // Writes the (optionally sliced) dimensions of input 0 into output 0.
  // Output 0 is requested exactly once, with its final shape.
  Status Compute(OpKernelContext* context) const override {
    const auto* input = context->Input<Tensor>(0);
    const TensorShape& input_shape = input->Shape();
    const int64_t rank = gsl::narrow_cast<int64_t>(input_shape.NumDimensions());

    if (!needs_slicing_) {  // vanilla use of Shape (no slicing)
      Tensor* output = context->Output(0, {rank});
      input_shape.CopyDims(output->template MutableData<int64_t>(), static_cast<size_t>(rank));
      return Status::OK();
    }

    // Slicing is needed: resolve negative indices and clamp both ends.
    const int64_t true_start = NormalizeIndex(start_index_, rank);
    const int64_t true_end = NormalizeIndex(end_index_, rank);

    // An inverted range (end <= start) yields an empty output.
    const int64_t slice_length = true_end > true_start ? true_end - true_start : 0;

    Tensor* output = context->Output(0, {slice_length});
    if (slice_length > 0) {
      // Both values are within [0, rank] here, so the casts cannot wrap.
      input_shape.CopyDims(output->template MutableData<int64_t>(),
                           static_cast<size_t>(true_start),
                           static_cast<size_t>(slice_length));
    }

    return Status::OK();
  }

 private:
  // Maps a possibly negative index onto [0, rank]: negative values count back
  // from `rank`, and anything still out of range is clamped.
  static int64_t NormalizeIndex(int64_t index, int64_t rank) {
    if (index < 0) {
      index += rank;
    }
    if (index < 0) {
      return 0;
    }
    return index > rank ? rank : index;
  }

  bool needs_slicing_ = false;
  int64_t start_index_ = 0;
  // Default "end" is the largest int64 so that it clamps to the input rank.
  int64_t end_index_ = std::numeric_limits<int64_t>::max();
};
} //namespace onnxruntime

View file

@ -892,7 +892,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum);
//OpSet 13
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Pow);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int64_t, Add);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, uint32_t, Add);
@ -1005,7 +1005,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Cast);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Cast);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Reshape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Shape);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Shape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Size);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Transpose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterElements);
@ -1163,9 +1163,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, LSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, ReduceMin);
@ -1182,6 +1182,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Relu);
#endif
//OpSet 15
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, Pow);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, float, BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, double, BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, MLFloat16, BatchNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, Shape);
template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
return {};
@ -1728,7 +1735,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum)>,
// OpSet 13
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int64_t, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, uint32_t, Add)>,
@ -1841,7 +1848,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Cast)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Cast)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Size)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Transpose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterElements)>,
@ -1998,9 +2005,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, ReduceMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, ReduceMin)>,
@ -2015,6 +2022,14 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Relu)>,
#endif
// OpSet 15
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, Shape)>,
};
for (auto& function_table_entry : function_table) {

View file

@ -205,9 +205,9 @@ Status BinaryElementwise<ShouldBroadcast>::Prepare(OpKernelContext* context, Bin
#define BINARY_OP_TYPED_VERSIONED_V_BF16(name, class_name, startver, endver)
#endif
// Expands BINARY_OP_VERSIONED_TYPED once for each of MLFloat16, float and
// double over the version range [startver, endver].
#define BINARY_OP_VERSIONED_HFD(name, startver, endver)        \
  BINARY_OP_VERSIONED_TYPED(name, startver, endver, MLFloat16) \
  BINARY_OP_VERSIONED_TYPED(name, startver, endver, float)     \
  BINARY_OP_VERSIONED_TYPED(name, startver, endver, double)
#define BINARY_OP_VERSIONED_UZILHFD(name, startver, endver) \
@ -318,15 +318,29 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kOnnxDomain,
12, 12,
kCudaExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>()).TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>()),
(*KernelDefBuilder::Create())
.TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>())
.TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>()),
Pow);
// Pow on the CUDA execution provider, opsets 13-14. Registered with a closed
// version range so that it does not overlap the opset-15 registration below.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Pow,
    kOnnxDomain,
    13, 14,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>())
        .TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>()),
    Pow);

// Pow on the CUDA execution provider, opset 15 onwards. Same "T" / "T1" type
// sets and the same kernel class as the 13-14 registration.
ONNX_OPERATOR_KERNEL_EX(
    Pow,
    kOnnxDomain,
    15,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>())
        .TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>()),
    Pow);
namespace pow12_internal {
@ -524,6 +538,5 @@ BINARY_OP_REGISTER_VERSIONED_HFD(Less, 7, 8)
BINARY_LOGICALOP_REGISTER_UZILHFD(GreaterOrEqual, 12)
BINARY_LOGICALOP_REGISTER_UZILHFD(LessOrEqual, 12)
} // namespace cuda
} // namespace onnxruntime

View file

@ -11,33 +11,45 @@ using namespace std;
namespace onnxruntime {
namespace cuda {
// Registers BatchNorm<T> with the CUDA execution provider once per opset range:
//   7-8, 9-13 : single "T" constraint
//   14 only   : "T" and "U", both bound to T
//   15+       : "T", "T1" and "T2", all bound to T
// Comments are kept outside the macro body because a `//` comment on a
// continuation line would swallow the trailing backslash.
#define REGISTER_KERNEL_TYPED(T)                                  \
  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(                        \
      BatchNormalization,                                         \
      kOnnxDomain,                                                \
      7, 8,                                                       \
      T,                                                          \
      kCudaExecutionProvider,                                     \
      (*KernelDefBuilder::Create())                               \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
      BatchNorm<T>);                                              \
  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(                        \
      BatchNormalization,                                         \
      kOnnxDomain,                                                \
      9, 13,                                                      \
      T,                                                          \
      kCudaExecutionProvider,                                     \
      (*KernelDefBuilder::Create())                               \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
      BatchNorm<T>);                                              \
  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(                        \
      BatchNormalization,                                         \
      kOnnxDomain,                                                \
      14, 14,                                                     \
      T,                                                          \
      kCudaExecutionProvider,                                     \
      (*KernelDefBuilder::Create())                               \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())  \
          .TypeConstraint("U", DataTypeImpl::GetTensorType<T>()), \
      BatchNorm<T>);                                              \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                  \
      BatchNormalization,                                         \
      kOnnxDomain,                                                \
      15,                                                         \
      T,                                                          \
      kCudaExecutionProvider,                                     \
      (*KernelDefBuilder::Create())                               \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())  \
          .TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()) \
          .TypeConstraint("T2", DataTypeImpl::GetTensorType<T>()), \
      BatchNorm<T>);
template <typename T>

View file

@ -39,8 +39,8 @@ class BatchNorm final : public CudaKernel {
const auto& node = op_kernel_info.node();
auto opset = node.SinceVersion();
// batch norm opset 14 is not implemented for training mode
ORT_ENFORCE(!(is_training_mode_ && opset==14), "Training mode does not support BN opset 14 yet.");
// batch norm opset 14 (or higher) is not implemented for training mode
ORT_ENFORCE(!(is_training_mode_ && opset >= 14), "Training mode does not support BN opset 14 (or higher) yet.");
}
Status ComputeInternal(OpKernelContext* context) const override;
@ -50,7 +50,7 @@ class BatchNorm final : public CudaKernel {
int64_t spatial_ = 1; // default as per spec
cudnnBatchNormMode_t cudnn_batch_norm_mode_;
double momentum_;
bool is_training_mode_ = 0; //default as per spec
bool is_training_mode_ = 0; //default as per spec
};
} // namespace cuda

View file

@ -20,10 +20,22 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
Shape);
// Shape on the CUDA execution provider, opsets 13-14 (closed version range).
// "T" accepts every fixed-size tensor type; "T1" is int64. Output 0 is
// registered with the OrtMemTypeCPUInput memory type.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Shape,
kOnnxDomain,
13, 14,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
// properly force CPU/GPU synch inside the kernel
.OutputMemoryType(OrtMemTypeCPUInput, 0)
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
Shape);
ONNX_OPERATOR_KERNEL_EX(
Shape,
kOnnxDomain,
13,
15,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
// properly force CPU/GPU synch inside the kernel

View file

@ -827,7 +827,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, Einsum);
//OpSet 13
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Pow);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Pow);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int64_t, Add);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, uint32_t, Add);
@ -934,7 +934,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Cast);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, bool, Cast);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Reshape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Shape);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Transpose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterElements);
@ -1029,6 +1029,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad);
// opset 15
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 15, Shape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 15, Pow);
template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
KernelCreateInfo info;
@ -1555,7 +1559,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, Einsum)>,
// OpSet 13
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int64_t, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, uint32_t, Add)>,
@ -1662,7 +1666,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Cast)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, bool, Cast)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Transpose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterElements)>,
@ -1756,6 +1760,10 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad)>,
// opset 15
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 15, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 15, Pow)>,
};
for (auto& function_table_entry : function_table) {

View file

@ -291,7 +291,8 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
double per_sample_tolerance = 1e-3;
// when cuda is enabled, set it to a larger value for resolving random MNIST test failure
// when openvino is enabled, set it to a larger value for resolving MNIST accuracy mismatch
double relative_per_sample_tolerance = enable_cuda ? 0.017 : enable_openvino ? 0.009 : 1e-3;
double relative_per_sample_tolerance = enable_cuda ? 0.017 : enable_openvino ? 0.009
: 1e-3;
Ort::SessionOptions sf;
@ -480,8 +481,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
ORT_TSTR("operator_pow"),
ORT_TSTR("bernoulli"),
ORT_TSTR("bernoulli_double"),
ORT_TSTR("bernoulli_seed")
};
ORT_TSTR("bernoulli_seed")};
static const ORTCHAR_T* cuda_flaky_tests[] = {
ORT_TSTR("fp16_inception_v1"),
@ -600,16 +600,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
{"bernoulli_seed", "By design. Test data is for informational purpose because the generator is non deterministic."},
{"bernoulli_seed_expanded", "By design. Test data is for informational purpose because the generator is non deterministic."},
{"bernoulli_expanded", "By design. Test data is for informational purpose because the generator is non deterministic."},
{"shape", "opset15 updates not supported yet."},
{"shape_clip_end", "opset15 updates not supported yet."},
{"shape_clip_start", "opset15 updates not supported yet."},
{"shape_end_1", "opset15 updates not supported yet."},
{"shape_end_negative_1", "opset15 updates not supported yet."},
{"shape_example", "opset15 updates not supported yet."},
{"shape_start_1", "opset15 updates not supported yet."},
{"shape_start_1_end_2", "opset15 updates not supported yet."},
{"shape_start_1_end_negative_1", "opset15 updates not supported yet."},
{"shape_start_negative_1", "opset15 updates not supported yet."},
{"test_optional_get_element", "opset15 updates not supported yet."},
{"test_optional_get_element_sequence", "opset15 updates not supported yet."},
{"test_optional_has_element", "opset15 updates not supported yet."},

View file

@ -53,7 +53,6 @@
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/shape_to_initializer.h"
#include "core/optimizer/skip_layer_norm_fusion.h"
#include "core/optimizer/slice_elimination.h"
#include "core/optimizer/unsqueeze_elimination.h"
@ -480,7 +479,7 @@ TEST_F(GraphTransformationTests, ConstantFolding_RemoveDanglingInputNodesToConst
ASSERT_TRUE(op_to_count["RandomUniform"] == 0);
}
TEST_F(GraphTransformationTests, ShapeToInitializer) {
TEST_F(GraphTransformationTests, ConstantFoldingAShapeNodeDeepInTheGraph) {
auto model_uri = MODEL_FOLDER "shape-add.onnx";
std::shared_ptr<Model> model;
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
@ -489,17 +488,21 @@ TEST_F(GraphTransformationTests, ShapeToInitializer) {
ASSERT_TRUE(op_to_count["Shape"] == 4);
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
rule_transformer_L1->Register(std::make_unique<ShapeToInitializer>());
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
std::unique_ptr<CPUExecutionProvider> e =
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
graph_transformation_mgr.Register(std::make_unique<ConstantFolding>(*e.get(),
false /*skip_dequantize_linear*/),
TransformerLevel::Level1);
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
op_to_count = CountOpsInGraph(graph);
// Two of the Shapes are not eliminated because:
// One includes a symbolic dimension.
// Another one includes a negative dimension
ASSERT_TRUE(op_to_count["Shape"] == 2);
// A Shape node very deep in the graph (feeding into an Identity
// node that produces the graph output) gets constant folded which
// removes all its ancestors and the Identity node consuming this Shape's
// output is subsequently constant folded to leave the graph with no
// nodes.
ASSERT_TRUE(op_to_count.size() == 0);
}
// Check transformations in the case of a subgraph with constant inputs.
@ -674,8 +677,8 @@ TEST_F(GraphTransformationTests, FuseCudaConvAddRelu) {
graph_transformation_mgr.Register(std::make_unique<ConvActivationFusion>(), TransformerLevel::Level2);
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Add"] == 0); //Add removed from graph
ASSERT_TRUE(op_to_count["Relu"] == 0); //Relu removed from graph
ASSERT_TRUE(op_to_count["Add"] == 0); //Add removed from graph
ASSERT_TRUE(op_to_count["Relu"] == 0); //Relu removed from graph
}
//Conv->Add->Relu will be left intact since there is Identity depend on Add
@ -695,9 +698,9 @@ TEST_F(GraphTransformationTests, FuseCudaConvAddReluIdentity) {
graph_transformation_mgr.Register(std::make_unique<ConvActivationFusion>(), TransformerLevel::Level2);
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Add"] == 1); //Add remains
ASSERT_TRUE(op_to_count["Relu"] == 1); //Relu remains
ASSERT_TRUE(op_to_count["Identity"] == 1); //Identity remains
ASSERT_TRUE(op_to_count["Add"] == 1); //Add remains
ASSERT_TRUE(op_to_count["Relu"] == 1); //Relu remains
ASSERT_TRUE(op_to_count["Identity"] == 1); //Identity remains
}
//Conv->Add will be left intact since there is no Relu follows
@ -715,7 +718,7 @@ TEST_F(GraphTransformationTests, FuseCudaConvAdd) {
graph_transformation_mgr.Register(std::make_unique<ConvActivationFusion>(), TransformerLevel::Level2);
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Add"] == 1); //Add remains, no transform applied to the graph
ASSERT_TRUE(op_to_count["Add"] == 1); //Add remains, no transform applied to the graph
}
#endif
@ -4131,13 +4134,13 @@ TEST_F(GraphTransformationTests, FilterEnabledOptimizers) {
const auto& graph = session_object.GetGraph();
// check the ops that should go away if the constant folding transformer or ShapeToInitializer rewrite rule run
// check the ops that should go away if the constant folding transformer runs
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Shape"] == 1);
ASSERT_TRUE(op_to_count["ConstantOfShape"] == 1);
ASSERT_TRUE(op_to_count["Add"] == 1);
ASSERT_STATUS_OK(session_object.FilterEnabledOptimizers({"ConstantFolding", "ShapeToInitializer"}));
ASSERT_STATUS_OK(session_object.FilterEnabledOptimizers({"ConstantFolding"}));
ASSERT_STATUS_OK(session_object.Initialize()); // Initialize runs the transformers
op_to_count = CountOpsInGraph(graph);

View file

@ -688,6 +688,21 @@ TEST(MathOpTest, Pow_Float_12) {
test.Run();
}
// Element-wise float Pow at opset 15 on a 2x2 tensor:
//   2^0 = 1, 2^8 = 256, sqrt(2)^2 = 2, 1^9 = 1.
TEST(MathOpTest, Pow_Float_15) {
  const std::vector<int64_t> shape{2, 2};
  const std::vector<float> base{2.0f, 2.0f,
                                std::sqrt(2.0f), 1.0f};
  const std::vector<float> exponent{0.0f, 8.0f,
                                    2.0f, 9.0f};
  const std::vector<float> expected{1.0f, 256.0f,
                                    2.0f, 1.0f};

  OpTester test("Pow", 15);
  test.AddInput<float>("X", shape, base);
  test.AddInput<float>("Y", shape, exponent);
  test.AddOutput<float>("Z", shape, expected);
  test.Run();
}
TEST(MathOpTest, Pow_Double_12) {
OpTester test("Pow", 12);
std::vector<int64_t> dims{2, 2};
@ -1635,7 +1650,7 @@ TEST(MathOpTest, LessOrEqual) {
test.AddInput<float>("B", dims, {1.0f, 1.0f, 2.0f, -1.0f});
test.AddOutput<bool>("C", dims, {true, true, true, true});
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
}
TEST(MathOpTest, LessOrEqual_Scalar0) {
@ -1644,7 +1659,7 @@ TEST(MathOpTest, LessOrEqual_Scalar0) {
test.AddInput<float>("B", {4}, {1.0f, 1.5f, 2.0f, -1.0f});
test.AddOutput<bool>("C", {4}, {true, true, true, false});
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
}
TEST(MathOpTest, LessOrEqual_Scalar1) {
@ -1653,7 +1668,7 @@ TEST(MathOpTest, LessOrEqual_Scalar1) {
test.AddInput<float>("B", {1}, {1.0f});
test.AddOutput<bool>("C", {4}, {true, true, false, true});
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
}
TEST(MathOpTest, LessOrEqual_int64_Scalar1) {
@ -1662,7 +1677,7 @@ TEST(MathOpTest, LessOrEqual_int64_Scalar1) {
test.AddInput<int64_t>("B", {1}, {1});
test.AddOutput<bool>("C", {4}, {true, true, false, true});
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
}
TEST(MathOpTest, LessOrEqual_broadcastAB) {
OpTester test("LessOrEqual", 12);
@ -1670,7 +1685,7 @@ TEST(MathOpTest, LessOrEqual_broadcastAB) {
test.AddInput<int32_t>("B", {2}, {15, 7});
test.AddOutput<bool>("C", {4, 2}, {true, false, true, false, true, false, false, false});
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
}
TEST(MathOpTest, LessOrEqual_broadcastBA) {
@ -1679,7 +1694,7 @@ TEST(MathOpTest, LessOrEqual_broadcastBA) {
test.AddInput<int32_t>("B", {4, 2}, {10, 11, 12, 13, 14, 15, 16, 17});
test.AddOutput<bool>("C", {4, 2}, {false, true, false, true, false, true, true, true});
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
}
TEST(MathOpTest, LessOrEqual_multidiretional_broadcastAB) {
@ -1688,7 +1703,7 @@ TEST(MathOpTest, LessOrEqual_multidiretional_broadcastAB) {
test.AddInput<int32_t>("B", {2}, {15, 7});
test.AddOutput<bool>("C", {4, 2}, {true, false, true, false, true, false, true, false});
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
}
TEST(MathOpTest, LessOrEqual_multidiretional_broadcastBA) {
@ -1697,7 +1712,7 @@ TEST(MathOpTest, LessOrEqual_multidiretional_broadcastBA) {
test.AddInput<int32_t>("B", {4, 1}, {10, 11, 12, 13});
test.AddOutput<bool>("C", {4, 2}, {false, true, false, true, false, true, false, true});
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
}
TEST(MathOpTest, Greater_7) {
@ -1784,7 +1799,7 @@ TEST(MathOpTest, GreaterOrEqual_12_float) {
test.AddInput<float>("B", dims, {1.0f, 1.0f, 2.0f, -1.0f});
test.AddOutput<bool>("C", dims, {true, false, false, true});
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
}
TEST(MathOpTest, GreaterOrEqual_12_double) {
@ -1794,7 +1809,7 @@ TEST(MathOpTest, GreaterOrEqual_12_double) {
test.AddInput<double>("B", dims, {1.0, 1.0, 2.0, -1.0});
test.AddOutput<bool>("C", dims, {true, false, true, true});
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
}
TEST(MathOpTest, GreaterOrEqual_12_int32) {
@ -1804,7 +1819,7 @@ TEST(MathOpTest, GreaterOrEqual_12_int32) {
test.AddInput<int32_t>("B", dims, {15, 7, 12, 9});
test.AddOutput<bool>("C", dims, {false, true, true, true});
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
}
TEST(MathOpTest, GreaterOrEqual_12_int64) {
@ -1814,7 +1829,7 @@ TEST(MathOpTest, GreaterOrEqual_12_int64) {
test.AddInput<int64_t>("B", dims, {15, 7, 12, 9});
test.AddOutput<bool>("C", dims, {false, true, true, true});
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
}
TEST(MathOpTest, GreaterOrEqual_broadcastAB) {
@ -1823,7 +1838,7 @@ TEST(MathOpTest, GreaterOrEqual_broadcastAB) {
test.AddInput<int32_t>("B", {2}, {15, 7});
test.AddOutput<bool>("C", {4, 2}, {false, true, false, true, false, true, true, true});
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
}
TEST(MathOpTest, GreaterOrEqual_broadcastBA) {
@ -1832,7 +1847,7 @@ TEST(MathOpTest, GreaterOrEqual_broadcastBA) {
test.AddInput<int32_t>("B", {4, 2}, {10, 11, 12, 13, 14, 15, 16, 17});
test.AddOutput<bool>("C", {4, 2}, {true, false, true, false, true, false, false, false});
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
}
TEST(MathOpTest, GreaterOrEqual_multidiretional_broadcastAB) {
@ -1841,7 +1856,7 @@ TEST(MathOpTest, GreaterOrEqual_multidiretional_broadcastAB) {
test.AddInput<int32_t>("B", {2}, {15, 7});
test.AddOutput<bool>("C", {4, 2}, {false, true, false, true, false, true, false, true});
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
}
TEST(MathOpTest, GreaterOrEqual_multidiretional_broadcastBA) {
@ -1850,7 +1865,7 @@ TEST(MathOpTest, GreaterOrEqual_multidiretional_broadcastBA) {
test.AddInput<int32_t>("B", {4, 1}, {10, 11, 12, 13});
test.AddOutput<bool>("C", {4, 2}, {true, false, true, false, true, false, true, false});
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
{kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
}
TEST(MathOpTest, Equal_bool) {

View file

@ -790,5 +790,30 @@ TEST(BatchNormTest, ForwardTrainingTestOpset14) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
}
// Runs BatchNormalization at opset 15 with training_mode = 1 on a 2x2x2x2
// float input and checks the normalized output together with the two
// per-channel running statistics.
TEST(BatchNormTest, ForwardTrainingTestOpset15) {
  OpTester tester("BatchNormalization", 15);

  const float eps = 1e-05f;
  const float running_momentum = 0.1f;
  const int64_t is_training = 1;
  tester.AddAttribute("epsilon", eps);
  tester.AddAttribute("momentum", running_momentum);
  tester.AddAttribute("training_mode", is_training);

  const std::vector<int64_t> data_dims{2, 2, 2, 2};
  const std::vector<int64_t> per_channel_dims{2};

  tester.AddInput<float>("X", data_dims,
                         {-0.2953f, 0.1180f, 1.0973f, -0.1931f,
                          -0.1999f, -0.0237f, 1.5181f, 0.0076f,
                          -1.0830f, -1.5433f, 0.4327f, -0.9813f,
                          0.7875f, -0.4080f, -2.3144f, 1.5493f});
  // Unit scale and zero bias: Y is the plain normalized input.
  tester.AddInput<float>("scale", per_channel_dims, {1.0f, 1.0f});
  tester.AddInput<float>("B", per_channel_dims, {0.0f, 0.0f});
  tester.AddInput<float>("mean", per_channel_dims, {1.0f, 2.0f});
  tester.AddInput<float>("var", per_channel_dims, {1.0f, 2.0f});

  tester.AddOutput<float>("Y", data_dims,
                          {0.0131f, 0.5210f, 1.7244f, 0.1387f,
                           -0.2708f, -0.1191f, 1.2089f, -0.0922f,
                           -0.9548f, -1.5203f, 0.9077f, -0.8298f,
                           0.5796f, -0.4501f, -2.0921f, 1.2358f});
  // Expected running statistics blend the supplied mean/var with the batch
  // statistics, e.g. channel 0 mean: 0.1 * 1.0 + 0.9 * (-0.306) ~= -0.1754.
  tester.AddOutput<float>("running_mean", per_channel_dims, {-0.1754f, 0.303106f});
  tester.AddOutput<float>("running_var", per_channel_dims, {0.696052f, 1.41316f});

  // Skip the same execution providers that the opset 14 variant of this test skips.
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "",
             {kCudaExecutionProvider, kTensorrtExecutionProvider,
              kOpenVINOExecutionProvider, kDnnlExecutionProvider});
}
} // namespace test
} // namespace onnxruntime

View file

@ -4,26 +4,75 @@
namespace onnxruntime {
namespace test {
template<typename T>
void TestShape(const std::initializer_list<T>& data, const std::vector<int64_t>& shape)
{
template <typename T>
void TestShape(const std::initializer_list<T>& data, const std::vector<int64_t>& shape) {
OpTester test("Shape");
test.AddInput<T>("data", shape, data);
test.AddOutput<int64_t>("output", {static_cast<int64_t>(shape.size())}, shape);
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});//TensorRT parser: unsupported data types
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT parser: unsupported data types
}
TEST(ShapeOpTest, ShapeTestBool) { TestShape <bool> ({true, true, false, false, true, false}, {2, 3}); }
TEST(ShapeOpTest, ShapeTestFloat) { TestShape <float> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {2, 6}); }
TEST(ShapeOpTest, ShapeTestDouble) { TestShape <double> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {6, 2}); }
TEST(ShapeOpTest, ShapeTestInt8) { TestShape <int8_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {3, 4}); }
TEST(ShapeOpTest, ShapeTestInt16) { TestShape <int16_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {3, 4}); }
TEST(ShapeOpTest, ShapeTestInt32) { TestShape <int32_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {4, 3}); }
TEST(ShapeOpTest, ShapeTestInt64) { TestShape <int64_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); }
TEST(ShapeOpTest, ShapeTestUint8) { TestShape <uint8_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {12, 1}); }
TEST(ShapeOpTest, ShapeTestUint16) { TestShape <uint16_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); }
TEST(ShapeOpTest, ShapeTestUint32) { TestShape <uint32_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {12, 1}); }
TEST(ShapeOpTest, ShapeTestUint64) { TestShape <uint64_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); }
// One Shape test per element type. Each forwards twelve (six for bool)
// elements and the tensor dims to TestShape.
TEST(ShapeOpTest, ShapeTestBool) {
  TestShape<bool>({true, true, false, false, true, false}, {2, 3});
}

TEST(ShapeOpTest, ShapeTestFloat) {
  TestShape<float>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {2, 6});
}

TEST(ShapeOpTest, ShapeTestDouble) {
  TestShape<double>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {6, 2});
}

TEST(ShapeOpTest, ShapeTestInt8) {
  TestShape<int8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {3, 4});
}

TEST(ShapeOpTest, ShapeTestInt16) {
  TestShape<int16_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {3, 4});
}

TEST(ShapeOpTest, ShapeTestInt32) {
  TestShape<int32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {4, 3});
}

TEST(ShapeOpTest, ShapeTestInt64) {
  TestShape<int64_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12});
}

TEST(ShapeOpTest, ShapeTestUint8) {
  TestShape<uint8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {12, 1});
}

TEST(ShapeOpTest, ShapeTestUint16) {
  TestShape<uint16_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12});
}

TEST(ShapeOpTest, ShapeTestUint32) {
  TestShape<uint32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {12, 1});
}

TEST(ShapeOpTest, ShapeTestUint64) {
  TestShape<uint64_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12});
}

TEST(ShapeOpTest, ShapeTestString) {
  TestShape<std::string>({"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"}, {1, 12});
}
// Opset 15 Shape with neither `start` nor `end` set: all three dims of the
// {1, 2, 2} input are expected in the output.
TEST(ShapeOpTest, ShapeOpset15_Default) {
  OpTester tester("Shape", 15);
  const std::vector<int64_t> data_dims{1, 2, 2};

  tester.AddInput<int32_t>("data", data_dims, {1, 2, 3, 4});
  tester.AddOutput<int64_t>("output", {3}, {1, 2, 2});

  // TensorRT parser: unsupported data types
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
// Opset 15 Shape with only `start` = 1: the leading dim of the {1, 2, 2}
// input is dropped, leaving {2, 2}.
TEST(ShapeOpTest, ShapeOpset15_StartOnly) {
  OpTester tester("Shape", 15);
  const int64_t slice_start = 1;
  const std::vector<int64_t> data_dims{1, 2, 2};

  tester.AddAttribute<int64_t>("start", slice_start);
  tester.AddInput<int32_t>("data", data_dims, {1, 2, 3, 4});
  tester.AddOutput<int64_t>("output", {2}, {2, 2});

  // TensorRT parser: unsupported data types
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
// Opset 15 Shape with only `end` = 2: the first two dims of the {1, 2, 2}
// input are kept, giving {1, 2}.
TEST(ShapeOpTest, ShapeOpset15_EndOnly) {
  OpTester tester("Shape", 15);
  const int64_t slice_end = 2;
  const std::vector<int64_t> data_dims{1, 2, 2};

  tester.AddAttribute<int64_t>("end", slice_end);
  tester.AddInput<int32_t>("data", data_dims, {1, 2, 3, 4});
  tester.AddOutput<int64_t>("output", {2}, {1, 2});

  // TensorRT parser: unsupported data types
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
// Opset 15 Shape with `start` = 1 and `end` = 2: a single dim of the
// {1, 2, 2} input is expected, {2}.
TEST(ShapeOpTest, ShapeOpset15_StartAndEnd) {
  OpTester tester("Shape", 15);
  const int64_t slice_start = 1;
  const int64_t slice_end = 2;
  const std::vector<int64_t> data_dims{1, 2, 2};

  tester.AddAttribute<int64_t>("start", slice_start);
  tester.AddAttribute<int64_t>("end", slice_end);
  tester.AddInput<int32_t>("data", data_dims, {1, 2, 3, 4});
  tester.AddOutput<int64_t>("output", {1}, {2});

  // TensorRT parser: unsupported data types
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
// Opset 15 Shape with negative bounds (`start` = -2, `end` = -1) on a rank-3
// input. Per the ONNX spec negative values count back from the rank, so a
// single dim is expected in the output, {2}.
TEST(ShapeOpTest, ShapeOpset15_StartAndEndNegative) {
  OpTester tester("Shape", 15);
  const int64_t slice_start = -2;
  const int64_t slice_end = -1;
  const std::vector<int64_t> data_dims{1, 2, 2};

  tester.AddAttribute<int64_t>("start", slice_start);
  tester.AddAttribute<int64_t>("end", slice_end);
  tester.AddInput<int32_t>("data", data_dims, {1, 2, 3, 4});
  tester.AddOutput<int64_t>("output", {1}, {2});

  // TensorRT parser: unsupported data types
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
// Opset 15 Shape with `start` == `end` (both 2): the selected range is empty,
// so a zero-length output tensor is expected.
TEST(ShapeOpTest, ShapeOpset15_StartAndEndProducingEmptySlice) {
  OpTester tester("Shape", 15);
  const int64_t slice_start = 2;
  const int64_t slice_end = 2;
  const std::vector<int64_t> data_dims{1, 2, 2};

  tester.AddAttribute<int64_t>("start", slice_start);
  tester.AddAttribute<int64_t>("end", slice_end);
  tester.AddInput<int32_t>("data", data_dims, {1, 2, 3, 4});
  tester.AddOutput<int64_t>("output", {0}, {});

  // TensorRT parser: unsupported data types
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
} // namespace test
} // namespace onnxruntime

View file

@ -255,6 +255,14 @@
"BatchNormalization ai.onnx CPUExecutionProvider",
17832136363477464736
],
[
"BatchNormalization ai.onnx CPUExecutionProvider",
3016597991190826984
],
[
"BatchNormalization ai.onnx CPUExecutionProvider",
9270095107043637928
],
[
"BitShift ai.onnx CPUExecutionProvider",
4758677670685660688
@ -1483,6 +1491,10 @@
"Pow ai.onnx CPUExecutionProvider",
12963226513247425672
],
[
"Pow ai.onnx CPUExecutionProvider",
16138602580714332296
],
[
"PRelu ai.onnx CPUExecutionProvider",
3282999003886175808
@ -2159,6 +2171,10 @@
"Shape ai.onnx CPUExecutionProvider",
14989007508280400584
],
[
"Shape ai.onnx CPUExecutionProvider",
9917761852037658112
],
[
"Shrink ai.onnx CPUExecutionProvider",
4706529740707835200

View file

@ -67,7 +67,6 @@
"^test_add_uint8_cpu",
"^test_div_uint8_cpu",
// Following tests are for opset 15 ops and are not yet implemented in ORT
"^test_shape_*",
"^test_optional_*",
//GPU failures
"^test_batchnorm_epsilon_training_mode_cuda",

View file

@ -38,7 +38,6 @@
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/shape_to_initializer.h"
#include "core/optimizer/skip_layer_norm_fusion.h"
#include "core/optimizer/slice_elimination.h"
#include "core/optimizer/unsqueeze_elimination.h"
@ -75,7 +74,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
case TransformerLevel::Level1: {
rule_transformer =
std::make_unique<RuleBasedGraphTransformer>(optimizer_utils::GenerateRuleBasedTransformerName(level),
compatible_eps);
compatible_eps);
rule_transformer->Register(std::make_unique<InsertMaxPoolOutput>());
rule_transformer->Register(std::make_unique<BatchNormReplacement>());
rule_transformer->Register(std::make_unique<UnsqueezeElimination>());
@ -127,16 +126,16 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
if (config.propagate_cast_ops_config.level >= 0) {
std::unordered_set<std::string> cuda_execution_provider = {onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider};
transformers.emplace_back(std::make_unique<PropagateCastOps>(config.propagate_cast_ops_config.strategy,
static_cast<size_t>(config.propagate_cast_ops_config.level),
config.propagate_cast_ops_config.allow,
cuda_execution_provider));
static_cast<size_t>(config.propagate_cast_ops_config.level),
config.propagate_cast_ops_config.allow,
cuda_execution_provider));
}
} break;
case TransformerLevel::Level2: {
rule_transformer =
std::make_unique<RuleBasedGraphTransformer>(optimizer_utils::GenerateRuleBasedTransformerName(level),
compatible_eps);
compatible_eps);
rule_transformer->Register(std::make_unique<ConcatReplacement>());
} break;