Remove template<T> from RoiAlignBase (#4558)

This commit is contained in:
Tracy Sharpe 2020-07-20 14:28:46 -07:00 committed by GitHub
parent bbdabc2c48
commit 7f9d9557b1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 8 deletions

View file

@ -15,7 +15,6 @@ enum struct RoiAlignMode {
max
};
template <typename T>
class RoiAlignBase {
public:
explicit RoiAlignBase(const OpKernelInfo& info) {
@ -23,7 +22,11 @@ class RoiAlignBase {
std::string mode;
if (info.GetAttr<std::string>("mode", &mode).IsOK()) {
std::transform(mode.begin(), mode.end(), mode.begin(), [](char i) { return static_cast<char>(::tolower(i)); });
if (mode != "avg" && mode != "max") {
if (mode == "avg") {
mode_ = RoiAlignMode::avg;
} else if (mode == "max") {
mode_ = RoiAlignMode::max;
} else {
ORT_THROW("Invalid mode of value ", mode, " specified. It should be either avg or max");
}
mode_ = mode == "avg" ? RoiAlignMode::avg : RoiAlignMode::max;
@ -67,9 +70,9 @@ class RoiAlignBase {
};
template <typename T>
class RoiAlign final : public OpKernel, public RoiAlignBase<T> {
class RoiAlign final : public OpKernel, public RoiAlignBase {
public:
explicit RoiAlign(const OpKernelInfo& info) : OpKernel(info), RoiAlignBase<T>(info) {}
explicit RoiAlign(const OpKernelInfo& info) : OpKernel(info), RoiAlignBase(info) {}
Status Compute(OpKernelContext* context) const override;

View file

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
@ -12,8 +12,8 @@ namespace onnxruntime {
namespace cuda {
template <typename T>
struct RoiAlign final : CudaKernel, RoiAlignBase<T> {
RoiAlign(const OpKernelInfo& info) : CudaKernel(info), RoiAlignBase<T>(info) {}
struct RoiAlign final : CudaKernel, RoiAlignBase {
RoiAlign(const OpKernelInfo& info) : CudaKernel(info), RoiAlignBase(info) {}
Status ComputeInternal(OpKernelContext* context) const override;