fix crash when first input of BatchNormalization is 1-D (#23387)

### Description

fix crash when first input of BatchNormalization is 1-D
This commit is contained in:
Yulong Wang 2025-01-16 21:11:43 -08:00 committed by GitHub
parent 09c4cc7b36
commit b8599b786e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 6 additions and 1 deletions

View file

@ -75,9 +75,11 @@ class BatchNorm : public OpKernel {
const TensorShape& x_shape = X->Shape();
Tensor* Y = p_op_kernel_context->Output(0, x_shape);
// X shape is [N, C, D1, D2, ... Dn], but it can also be 1-D according to onnx spec:
// "The op also accepts single dimension input of size N in which case C is assumed to be 1"
const auto& dims_vec = x_shape.GetDims();
const size_t N = onnxruntime::narrow<size_t>(dims_vec[0]);
const size_t C = onnxruntime::narrow<size_t>(dims_vec[1]); // assume NCHW as per the spec
const size_t C = dims_vec.size() == 1 ? 1 : onnxruntime::narrow<size_t>(dims_vec[1]);
// calculate sample_size (per individual channel)
size_t sample_size = 1;

View file

@ -28,6 +28,9 @@ class BatchNormHelper {
// NHWC dependent shape: X
// All other shapes are assumed to be in NCHW layout?
const auto& x_dims = X->Shape().GetDims();
if (x_dims.size() < 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input X: NumDimensions() < 1");
}
// If x_dims size < 2, num_channels defaults to 1.
int64_t num_channels;