diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm.h b/onnxruntime/core/providers/cpu/nn/batch_norm.h index be9bc3368e..1febfac86d 100644 --- a/onnxruntime/core/providers/cpu/nn/batch_norm.h +++ b/onnxruntime/core/providers/cpu/nn/batch_norm.h @@ -75,9 +75,11 @@ class BatchNorm : public OpKernel { const TensorShape& x_shape = X->Shape(); Tensor* Y = p_op_kernel_context->Output(0, x_shape); + // X shape is [N, C, D1, D2, ... Dn], but it can also be 1-D according to onnx spec: + // "The op also accepts single dimension input of size N in which case C is assumed to be 1" const auto& dims_vec = x_shape.GetDims(); const size_t N = onnxruntime::narrow(dims_vec[0]); - const size_t C = onnxruntime::narrow(dims_vec[1]); // assume NCHW as per the spec + const size_t C = dims_vec.size() == 1 ? 1 : onnxruntime::narrow(dims_vec[1]); // calculate sample_size (per individual channel) size_t sample_size = 1; diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h b/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h index ccecbabfa3..b5aa522f71 100644 --- a/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h +++ b/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h @@ -28,6 +28,9 @@ class BatchNormHelper { // NHWC dependent shape: X // All other shapes are assumed to be in NCHW layout? const auto& x_dims = X->Shape().GetDims(); + if (x_dims.size() < 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input X: NumDimensions() < 1"); + } // If x_dims size < 2, num_channels defaults to 1. int64_t num_channels;