From efdd96595f5120af20ac8ffb8b1ad0917e8553da Mon Sep 17 00:00:00 2001 From: "Tang, Cheng" Date: Thu, 27 Aug 2020 16:10:53 -0700 Subject: [PATCH] bfloat16 and opset13 related fix (#4913) * regsiter part of opset13 cpu kernels; fix a bug in func impl; adjust reshapefusion order * remove useless function Co-authored-by: Cheng Tang --- onnxruntime/core/graph/function.cc | 9 +- .../providers/cpu/activation/activations.cc | 10 +- .../providers/cpu/cpu_execution_provider.cc | 100 +++++++++++++----- .../providers/cpu/math/element_wise_ops.cc | 9 +- .../core/providers/cpu/tensor/unsqueeze.cc | 11 +- .../core/optimizer/graph_transformer_utils.cc | 3 +- 6 files changed, 101 insertions(+), 41 deletions(-) diff --git a/onnxruntime/core/graph/function.cc b/onnxruntime/core/graph/function.cc index 65c4533412..9e647ab759 100644 --- a/onnxruntime/core/graph/function.cc +++ b/onnxruntime/core/graph/function.cc @@ -220,10 +220,6 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); } -static std::unordered_map GetOpsetVersionMap(const ONNX_NAMESPACE::FunctionProto& onnx_func_proto) { - return std::unordered_map{{onnxruntime::kOnnxDomain, static_cast(onnx_func_proto.since_version())}}; -} - FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, const onnxruntime::NodeIndex& node_index, const ONNX_NAMESPACE::FunctionProto& onnx_func_proto, @@ -232,7 +228,7 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, body_(onnx_func_proto.name(), false, onnxruntime::ModelMetaData(), graph.ModelPath().ToPathString(), IOnnxRuntimeOpSchemaRegistryList(), - GetOpsetVersionMap(onnx_func_proto), {}, logger), + graph.DomainToVersionMap(), {}, logger), onnx_func_proto_(onnx_func_proto) { // Make a copy of the FunctionProto. // All FunctionBody ops with the same op type seem to share the same FunctionProto struct within a model. @@ -298,9 +294,6 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, op_schema_->Finalize(); //construct body - std::unordered_map domain_to_version; - //TODO: set correct domain and version - domain_to_version[onnxruntime::kOnnxDomain] = static_cast(onnx_func_proto_.since_version()); auto& function_body_graph = body_.MainGraph(); std::vector graph_inputs(node_in_parent_graph->InputDefs().size(), nullptr), graph_outputs(node_in_parent_graph->OutputDefs().size(), nullptr); diff --git a/onnxruntime/core/providers/cpu/activation/activations.cc b/onnxruntime/core/providers/cpu/activation/activations.cc index 77fbc37064..34fd9e2e5e 100644 --- a/onnxruntime/core/providers/cpu/activation/activations.cc +++ b/onnxruntime/core/providers/cpu/activation/activations.cc @@ -49,6 +49,14 @@ template Status ElementWiseRangedTransform::Create(const std::string& typ #define REGISTER_UNARY_ELEMENTWISE_KERNEL(x, sinceVersion) REGISTER_UNARY_ELEMENTWISE_KERNEL_ALIAS(x, x, sinceVersion) +#define REGISTER_VERSIONED_UNARY_ELEMENTWISE_KERNEL_ALIAS(alias, x, sinceVersion, firstEnd, newVersion) \ + ONNX_CPU_OPERATOR_VERSIONED_KERNEL( \ + alias, sinceVersion, firstEnd, \ + KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()), x); \ + ONNX_CPU_OPERATOR_KERNEL( \ + alias, newVersion, \ + KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()), x); + REGISTER_UNARY_ELEMENTWISE_KERNEL(Elu, 6); REGISTER_UNARY_ELEMENTWISE_KERNEL(HardSigmoid, 6); REGISTER_UNARY_ELEMENTWISE_KERNEL(LeakyRelu, 6); @@ -57,7 +65,7 @@ REGISTER_UNARY_ELEMENTWISE_KERNEL(Selu, 6); REGISTER_UNARY_ELEMENTWISE_KERNEL(Sigmoid, 6); REGISTER_UNARY_ELEMENTWISE_KERNEL(Softplus, 1); REGISTER_UNARY_ELEMENTWISE_KERNEL(Softsign, 1); -REGISTER_UNARY_ELEMENTWISE_KERNEL(Tanh, 6); +REGISTER_VERSIONED_UNARY_ELEMENTWISE_KERNEL_ALIAS(Tanh, Tanh, 6, 12, 13); REGISTER_UNARY_ELEMENTWISE_KERNEL(ThresholdedRelu, 10); namespace functors { diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 2382a8d24f..ca89e8a060 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -38,7 +38,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Sel class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Sigmoid); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Softplus); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Softsign); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Tanh); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, Tanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Tanh); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, PRelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, RandomNormal); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, RandomUniform); @@ -198,18 +199,18 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Upsample); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, int32_t, Upsample); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, uint8_t, Upsample); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Expand); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, double, Expand); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int8_t, Expand); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int16_t, Expand); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int32_t, Expand); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int64_t, Expand); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint8_t, Expand); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint16_t, Expand); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint32_t, Expand); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint64_t, Expand); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, bool, Expand); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, MLFloat16, Expand); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, float, Expand); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, double, Expand); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, int8_t, Expand); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, int16_t, Expand); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, int32_t, Expand); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, int64_t, Expand); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, uint8_t, Expand); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, uint16_t, Expand); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, uint32_t, Expand); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, uint64_t, Expand); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, bool, Expand); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, MLFloat16, Expand); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 8, Scan); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Loop); @@ -355,7 +356,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Ga class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Slice); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Split); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Squeeze); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Det); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ScatterElements); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, NonMaxSuppression); @@ -429,6 +430,21 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Einsum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Sign); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Unsqueeze); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int16_t, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint16_t, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint32_t, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint64_t, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Expand); + // REVIEW(codemzs): ConstEigenVectorArrayMap.cast, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -713,29 +730,29 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { int32_t, Upsample)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -939,7 +956,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // REVIEW(codemzs): ConstEigenVectorArrayMap.cast, diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index e044a45f9f..01587a64da 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -1070,9 +1070,16 @@ Status Expand_8::Compute(OpKernelContext* context) const { } #define REG_EXPAND_KERNEL(TYPE) \ - ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ Expand, \ 8, \ + 12, \ + TYPE, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Expand_8); \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + Expand, \ + 13, \ TYPE, \ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ Expand_8); diff --git a/onnxruntime/core/providers/cpu/tensor/unsqueeze.cc b/onnxruntime/core/providers/cpu/tensor/unsqueeze.cc index 39be162f0b..4dbd286b80 100644 --- a/onnxruntime/core/providers/cpu/tensor/unsqueeze.cc +++ b/onnxruntime/core/providers/cpu/tensor/unsqueeze.cc @@ -18,9 +18,18 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( .TypeConstraint("T", DataTypeImpl::AllTensorTypes()), Unsqueeze); -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Unsqueeze, 11, + 12, + KernelDefBuilder() + .Alias(0, 0) + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()), + Unsqueeze); + +ONNX_CPU_OPERATOR_KERNEL( + Unsqueeze, + 13, KernelDefBuilder() .Alias(0, 0) .TypeConstraint("T", DataTypeImpl::AllTensorTypes()), diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 91b6d37846..fc16345af3 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -90,6 +90,7 @@ std::vector> GeneratePreTrainingTransformers( } transformers.emplace_back(onnxruntime::make_unique(execution_provider, compatible_eps, weights_to_train)); + transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); auto horizontal_parallel_size = training::DistributedRunContext::GroupSize(training::WorkerGroupType::HorizontalParallel); if (horizontal_parallel_size > 1) { LOGS_DEFAULT(WARNING) << horizontal_parallel_size << "-way horizontal model parallel is enabled"; @@ -101,8 +102,6 @@ std::vector> GeneratePreTrainingTransformers( } break; case TransformerLevel::Level2: { - // Put ReshapeFusion as level-2 optimization after all level-1 graph rewriters are run. - transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); rule_transformer = onnxruntime::make_unique(optimizer_utils::GenerateRuleBasedTransformerName(level), compatible_eps);