From ae66d0e7cf6774dc1b6435e122d3589251e6fbc8 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Tue, 24 Sep 2024 11:58:48 +1000 Subject: [PATCH] Update ROCm reduction to match recent CUDA change (#22192) ### Description Add handling of a missing optional axes input to the ROCm reduction ops. Matches CUDA EP change from #22149 ### Motivation and Context Fix pipeline. --- onnxruntime/core/providers/rocm/reduction/reduction_ops.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc index 11073ab358..a1f5eba9a2 100644 --- a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc @@ -731,10 +731,9 @@ Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, miopenR std::vector axes; size_t num_inputs = ctx->InputCount(); - if (num_inputs == 2) { + const Tensor* axes_tensor = num_inputs == 2 ? ctx->Input(1) : nullptr; // optional input. may be nullptr. + if (axes_tensor != nullptr) { // override the attribute value with the input value for reduction_axes - const Tensor* axes_tensor = ctx->Input(1); - ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "An axes tensor must be a vector tensor."); auto nDims = static_cast(axes_tensor->Shape()[0]); const auto* data = axes_tensor->Data();