mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
fix crash when first input of BatchNormalization is 1-D (#23387)
### Description fix crash when first input of BatchNormalization is 1-D
This commit is contained in:
parent
09c4cc7b36
commit
b8599b786e
2 changed files with 6 additions and 1 deletions
|
|
@ -75,9 +75,11 @@ class BatchNorm : public OpKernel {
|
|||
const TensorShape& x_shape = X->Shape();
|
||||
Tensor* Y = p_op_kernel_context->Output(0, x_shape);
|
||||
|
||||
// X shape is [N, C, D1, D2, ... Dn], but it can also be 1-D according to onnx spec:
|
||||
// "The op also accepts single dimension input of size N in which case C is assumed to be 1"
|
||||
const auto& dims_vec = x_shape.GetDims();
|
||||
const size_t N = onnxruntime::narrow<size_t>(dims_vec[0]);
|
||||
const size_t C = onnxruntime::narrow<size_t>(dims_vec[1]); // assume NCHW as per the spec
|
||||
const size_t C = dims_vec.size() == 1 ? 1 : onnxruntime::narrow<size_t>(dims_vec[1]);
|
||||
|
||||
// calculate sample_size (per individual channel)
|
||||
size_t sample_size = 1;
|
||||
|
|
|
|||
|
|
@ -28,6 +28,9 @@ class BatchNormHelper {
|
|||
// NHWC dependent shape: X
|
||||
// All other shapes are assumed to be in NCHW layout?
|
||||
const auto& x_dims = X->Shape().GetDims();
|
||||
if (x_dims.size() < 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input X: NumDimensions() < 1");
|
||||
}
|
||||
|
||||
// If x_dims size < 2, num_channels defaults to 1.
|
||||
int64_t num_channels;
|
||||
|
|
|
|||
Loading…
Reference in a new issue