BatchNorm CPU does not support non-spatial cases - explicitly handle such cases (#890)

* BatchNorm CPU does not support non-spatial cases

* skip test in c#

* Update comments
This commit is contained in:
Hariharan Seshadri 2019-04-23 21:37:21 -07:00 committed by GitHub
parent d0f846aad5
commit 9d89b23d81
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 5 deletions

View file

@ -211,12 +211,12 @@ namespace Microsoft.ML.OnnxRuntime.Tests
[Fact]
private void TestPreTrainedModelsOpset7And8()
{
// 16-bit float not supported type in C#.
var skipModels = new List<String>() {
"tf_inception_v2",
"fp16_inception_v1",
"fp16_shufflenet",
"fp16_tiny_yolov2" };
"mxnet_arcface", // Model not supported by CPU execution provider
"tf_inception_v2", // TODO: Debug failing model, skipping for now
"fp16_inception_v1", // 16-bit float not supported type in C#.
"fp16_shufflenet", // 16-bit float not supported type in C#.
"fp16_tiny_yolov2" }; // 16-bit float not supported type in C#.
var disableContribOpsEnvVar = Environment.GetEnvironmentVariable("DisableContribOps");
var isContribOpsDisabled = (disableContribOpsEnvVar != null) ? disableContribOpsEnvVar.Equals("ON") : false;

View file

@ -32,6 +32,13 @@ class BatchNorm : public OpKernel {
explicit BatchNorm(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) {
auto st = op_kernel_info.GetAttr<float>("epsilon", &epsilon_);
ORT_ENFORCE(st.IsOK(), st.ErrorMessage());
// opset 6-8
int64_t spatial;
if (op_kernel_info.GetAttr<int64_t>("spatial", &spatial).IsOK()) {
ORT_ENFORCE(spatial == 1, "BatchNormalization kernel for CPU provider does not support non-spatial cases");
}
//TODO: momentum
}