diff --git a/onnxruntime/core/providers/cpu/object_detection/roialign.h b/onnxruntime/core/providers/cpu/object_detection/roialign.h index b4ab3dc600..98f7218da5 100644 --- a/onnxruntime/core/providers/cpu/object_detection/roialign.h +++ b/onnxruntime/core/providers/cpu/object_detection/roialign.h @@ -15,7 +15,6 @@ enum struct RoiAlignMode { max }; -template class RoiAlignBase { public: explicit RoiAlignBase(const OpKernelInfo& info) { @@ -23,7 +22,11 @@ class RoiAlignBase { std::string mode; if (info.GetAttr("mode", &mode).IsOK()) { std::transform(mode.begin(), mode.end(), mode.begin(), [](char i) { return static_cast(::tolower(i)); }); - if (mode != "avg" && mode != "max") { + if (mode == "avg") { + mode_ = RoiAlignMode::avg; + } else if (mode == "max") { + mode_ = RoiAlignMode::max; + } else { ORT_THROW("Invalid mode of value ", mode, " specified. It should be either avg or max"); } mode_ = mode == "avg" ? RoiAlignMode::avg : RoiAlignMode::max; @@ -67,9 +70,9 @@ class RoiAlignBase { }; template -class RoiAlign final : public OpKernel, public RoiAlignBase { +class RoiAlign final : public OpKernel, public RoiAlignBase { public: - explicit RoiAlign(const OpKernelInfo& info) : OpKernel(info), RoiAlignBase(info) {} + explicit RoiAlign(const OpKernelInfo& info) : OpKernel(info), RoiAlignBase(info) {} Status Compute(OpKernelContext* context) const override; diff --git a/onnxruntime/core/providers/cuda/object_detection/roialign.h b/onnxruntime/core/providers/cuda/object_detection/roialign.h index fdd0f95ccf..8e33e3d2ea 100644 --- a/onnxruntime/core/providers/cuda/object_detection/roialign.h +++ b/onnxruntime/core/providers/cuda/object_detection/roialign.h @@ -1,5 +1,5 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. #pragma once @@ -12,8 +12,8 @@ namespace onnxruntime { namespace cuda { template -struct RoiAlign final : CudaKernel, RoiAlignBase { - RoiAlign(const OpKernelInfo& info) : CudaKernel(info), RoiAlignBase(info) {} +struct RoiAlign final : CudaKernel, RoiAlignBase { + RoiAlign(const OpKernelInfo& info) : CudaKernel(info), RoiAlignBase(info) {} Status ComputeInternal(OpKernelContext* context) const override;