mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Update ROCm reduction to match recent CUDA change (#22192)
### Description <!-- Describe your changes. --> Add handling of a missing optional axes input to the ROCm reduction ops. Matches CUDA EP change from #22149 ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Fix pipeline.
This commit is contained in:
parent
0806879ad4
commit
ae66d0e7cf
1 changed files with 2 additions and 3 deletions
|
|
@ -731,10 +731,9 @@ Status ReduceKernel<allow_multi_axes>::ComputeImpl(OpKernelContext* ctx, miopenR
|
|||
std::vector<int64_t> axes;
|
||||
|
||||
size_t num_inputs = ctx->InputCount();
|
||||
if (num_inputs == 2) {
|
||||
const Tensor* axes_tensor = num_inputs == 2 ? ctx->Input<Tensor>(1) : nullptr; // optional input. may be nullptr.
|
||||
if (axes_tensor != nullptr) {
|
||||
// override the attribute value with the input value for reduction_axes
|
||||
const Tensor* axes_tensor = ctx->Input<Tensor>(1);
|
||||
ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null");
|
||||
ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "An axes tensor must be a vector tensor.");
|
||||
auto nDims = static_cast<size_t>(axes_tensor->Shape()[0]);
|
||||
const auto* data = axes_tensor->Data<int64_t>();
|
||||
|
|
|
|||
Loading…
Reference in a new issue