From b8599b786edf2bae5fcd78384f4ea43bf0a15959 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 16 Jan 2025 21:11:43 -0800 Subject: [PATCH] fix crash when first input of BatchNormalization is 1-D (#23387) ### Description fix crash when first input of BatchNormalization is 1-D --- onnxruntime/core/providers/cpu/nn/batch_norm.h | 4 +++- onnxruntime/core/providers/cpu/nn/batch_norm_helper.h | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm.h b/onnxruntime/core/providers/cpu/nn/batch_norm.h index be9bc3368e..1febfac86d 100644 --- a/onnxruntime/core/providers/cpu/nn/batch_norm.h +++ b/onnxruntime/core/providers/cpu/nn/batch_norm.h @@ -75,9 +75,11 @@ class BatchNorm : public OpKernel { const TensorShape& x_shape = X->Shape(); Tensor* Y = p_op_kernel_context->Output(0, x_shape); + // X shape is [N, C, D1, D2, ... Dn], but it can also be 1-D according to onnx spec: + // "The op also accepts single dimension input of size N in which case C is assumed to be 1" const auto& dims_vec = x_shape.GetDims(); const size_t N = onnxruntime::narrow(dims_vec[0]); - const size_t C = onnxruntime::narrow(dims_vec[1]); // assume NCHW as per the spec + const size_t C = dims_vec.size() == 1 ? 1 : onnxruntime::narrow(dims_vec[1]); // calculate sample_size (per individual channel) size_t sample_size = 1; diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h b/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h index ccecbabfa3..b5aa522f71 100644 --- a/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h +++ b/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h @@ -28,6 +28,9 @@ class BatchNormHelper { // NHWC dependent shape: X // All other shapes are assumed to be in NCHW layout? const auto& x_dims = X->Shape().GetDims(); + if (x_dims.size() < 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input X: NumDimensions() < 1"); + } // If x_dims size < 2, num_channels defaults to 1. int64_t num_channels;