From 9d89b23d81bbbf9d90fe5e69d4d90cddfb71dbd0 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 23 Apr 2019 21:37:21 -0700 Subject: [PATCH] BatchNorm CPU does not support non-spatial cases - explicitly handle such cases (#890) * BatchNorm CPU does not support non-spatial cases * skip test in c# * Update comments --- .../Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs | 10 +++++----- onnxruntime/core/providers/cpu/nn/batch_norm.h | 7 +++++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs index a96d885399..c57a72deac 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs @@ -211,12 +211,12 @@ namespace Microsoft.ML.OnnxRuntime.Tests [Fact] private void TestPreTrainedModelsOpset7And8() { - // 16-bit float not supported type in C#. var skipModels = new List() { - "tf_inception_v2", - "fp16_inception_v1", - "fp16_shufflenet", - "fp16_tiny_yolov2" }; + "mxnet_arcface", // Model not supported by CPU execution provider + "tf_inception_v2", // TODO: Debug failing model, skipping for now + "fp16_inception_v1", // 16-bit float not supported type in C#. + "fp16_shufflenet", // 16-bit float not supported type in C#. + "fp16_tiny_yolov2" }; // 16-bit float not supported type in C#. var disableContribOpsEnvVar = Environment.GetEnvironmentVariable("DisableContribOps"); var isContribOpsDisabled = (disableContribOpsEnvVar != null) ? disableContribOpsEnvVar.Equals("ON") : false; diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm.h b/onnxruntime/core/providers/cpu/nn/batch_norm.h index f4d634e923..00109bbaaa 100644 --- a/onnxruntime/core/providers/cpu/nn/batch_norm.h +++ b/onnxruntime/core/providers/cpu/nn/batch_norm.h @@ -32,6 +32,13 @@ class BatchNorm : public OpKernel { explicit BatchNorm(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) { auto st = op_kernel_info.GetAttr("epsilon", &epsilon_); ORT_ENFORCE(st.IsOK(), st.ErrorMessage()); + + // opset 6-8 + int64_t spatial; + if (op_kernel_info.GetAttr("spatial", &spatial).IsOK()) { + ORT_ENFORCE(spatial == 1, "BatchNormalization kernel for CPU provider does not support non-spatial cases"); + } + //TODO: momentum }