mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Remove template<T> from RoiAlignBase (#4558)
This commit is contained in:
parent
bbdabc2c48
commit
7f9d9557b1
2 changed files with 11 additions and 8 deletions
|
|
@ -15,7 +15,6 @@ enum struct RoiAlignMode {
|
|||
max
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class RoiAlignBase {
|
||||
public:
|
||||
explicit RoiAlignBase(const OpKernelInfo& info) {
|
||||
|
|
@ -23,7 +22,11 @@ class RoiAlignBase {
|
|||
std::string mode;
|
||||
if (info.GetAttr<std::string>("mode", &mode).IsOK()) {
|
||||
std::transform(mode.begin(), mode.end(), mode.begin(), [](char i) { return static_cast<char>(::tolower(i)); });
|
||||
if (mode != "avg" && mode != "max") {
|
||||
if (mode == "avg") {
|
||||
mode_ = RoiAlignMode::avg;
|
||||
} else if (mode == "max") {
|
||||
mode_ = RoiAlignMode::max;
|
||||
} else {
|
||||
ORT_THROW("Invalid mode of value ", mode, " specified. It should be either avg or max");
|
||||
}
|
||||
mode_ = mode == "avg" ? RoiAlignMode::avg : RoiAlignMode::max;
|
||||
|
|
@ -67,9 +70,9 @@ class RoiAlignBase {
|
|||
};
|
||||
|
||||
template <typename T>
|
||||
class RoiAlign final : public OpKernel, public RoiAlignBase<T> {
|
||||
class RoiAlign final : public OpKernel, public RoiAlignBase {
|
||||
public:
|
||||
explicit RoiAlign(const OpKernelInfo& info) : OpKernel(info), RoiAlignBase<T>(info) {}
|
||||
explicit RoiAlign(const OpKernelInfo& info) : OpKernel(info), RoiAlignBase(info) {}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
|
@ -12,8 +12,8 @@ namespace onnxruntime {
|
|||
namespace cuda {
|
||||
|
||||
template <typename T>
|
||||
struct RoiAlign final : CudaKernel, RoiAlignBase<T> {
|
||||
RoiAlign(const OpKernelInfo& info) : CudaKernel(info), RoiAlignBase<T>(info) {}
|
||||
struct RoiAlign final : CudaKernel, RoiAlignBase {
|
||||
RoiAlign(const OpKernelInfo& info) : CudaKernel(info), RoiAlignBase(info) {}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue