mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
Add FusedConv Op to ROCm (#11792)
* [ROCm] Add FusedConv Op. * Enable ROCm for FusedConvTest * [ROCm] Implement FusedConv Op. with Fusion API The old code path was left as the fallback since some combinations are not supported (e.g., FusedConvTest.Conv2D_Bias_Z_Relu as of ROCM 5.1, where to bias layers are needed). * [ROCM] Suppress duplicated warnings in unsupported Fusion API usage. Know limitation for current MIOpen (verified with ROCM 5.2): Only one bias layer may present in the Fusion Plan. Adding the second bias operation to the Fusion plan will end up with miopenStatusUnsupportedOp. In this case the fallback code path will be taken to complete required FusedConv operation. However, previously this failure was not detected and cached, and applications that create multiple FusedConv Ops with both z and bias will keep printing error messages, which is annoying to end users while this message is mainly for developers. This commit will let it print the first error message as a reminder, and skip the Fusion API code path in following calls if both z and bias present. (Note: the skipping applies to all newly created FusedConv Ops). * [ROCM] Add cache mechanism for FusedConv Op. Now the operator with the same configuration will share the same Fusion Plan object, and the creation result will also be cached. Two benefits: 1. No duplicated Fusion plan creation, which is a presumably very costly process. 2. Failures due to MIOpen limitations (like z and b cannot present at the same time) will only be triggered once. Know limits: Due to the limitation of MIOpen Interface, the tensor order of the convolution operator can only be guessed.
This commit is contained in:
parent
eb827bd3e5
commit
bc353c7afe
3 changed files with 460 additions and 10 deletions
451
onnxruntime/contrib_ops/rocm/fused_conv.cc
Normal file
451
onnxruntime/contrib_ops/rocm/fused_conv.cc
Normal file
|
|
@ -0,0 +1,451 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include "core/providers/rocm/nn/conv.h"
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
namespace {
|
||||
|
||||
// Copied from hipDNN/library/src/hcc_detail/hipdnn_miopen.cpp
|
||||
miopenStatus_t _miopenAddTensor(
|
||||
miopenHandle_t handle,
|
||||
const void *alpha,
|
||||
const miopenTensorDescriptor_t aDesc,
|
||||
const void *A,
|
||||
const void *beta,
|
||||
const miopenTensorDescriptor_t cDesc,
|
||||
void *C,
|
||||
const void* zero_scalar)
|
||||
{
|
||||
const miopenTensorOp_t tensorOp = miopenTensorOpAdd;
|
||||
// opnd2 = Add ( 0.0 * opnd0, alpha * opnd1 ) + alpha * opnd2
|
||||
return miopenOpTensor(handle, tensorOp,
|
||||
zero_scalar, cDesc, C,
|
||||
alpha, aDesc, A,
|
||||
alpha, cDesc, C);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template<uint32_t BASIS = 0x811C9DC5, uint32_t PRIME = 0x01000193>
|
||||
struct FNVHash {
|
||||
uint32_t GetValue() const { return value_; }
|
||||
|
||||
void Hash(const void* in_ptr, size_t nbytes) {
|
||||
auto ptr = reinterpret_cast<const uint8_t*>(in_ptr);
|
||||
for (size_t i = 0; i < nbytes; ++i) {
|
||||
value_ ^= ptr[i];
|
||||
value_ *= PRIME;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename std::enable_if<std::is_trivially_copyable<T>::value, size_t>::type = 0>
|
||||
FNVHash& operator<<(const T& pod) {
|
||||
Hash(&pod, sizeof(pod));
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FNVHash& operator<<(const std::vector<T>& pod_array) {
|
||||
for (const auto& pod : pod_array) {
|
||||
(*this) << pod;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
void HashTensor(miopenTensorDescriptor_t tdesc) {
|
||||
int size = 0;
|
||||
miopenGetTensorDescriptorSize(tdesc, &size);
|
||||
(*this) << size;
|
||||
std::vector<int> dims(size);
|
||||
std::vector<int> strides(size);
|
||||
miopenDataType_t dtype;
|
||||
miopenGetTensorDescriptor(tdesc, &dtype, dims.data(), strides.data());
|
||||
(*this) << dtype;
|
||||
(*this) << dims;
|
||||
(*this) << strides;
|
||||
}
|
||||
|
||||
void HashConvolutionDescriptor(miopenConvolutionDescriptor_t cdesc) {
|
||||
int spatial_dim = 1;
|
||||
// Current MIOpen doesn't provide API to probe the dimension of a
|
||||
// miopenConvolutionDescriptor_t, so we have to guess.
|
||||
// This algorithm is based on a specific behavior of miopenGetConvolutionNdDescriptor,
|
||||
// which fails when requestedSpatialDim > the convolution's spatial dimension
|
||||
std::vector<int> spatial_dims;
|
||||
std::vector<int> pads;
|
||||
std::vector<int> strides;
|
||||
std::vector<int> dilations;
|
||||
miopenConvolutionMode_t mode;
|
||||
while (true) {
|
||||
spatial_dims.resize(spatial_dim);
|
||||
pads.resize(spatial_dim);
|
||||
strides.resize(spatial_dim);
|
||||
dilations.resize(spatial_dim);
|
||||
|
||||
if (miopenStatusSuccess != miopenGetConvolutionNdDescriptor(cdesc,
|
||||
spatial_dim,
|
||||
spatial_dims.data(),
|
||||
pads.data(),
|
||||
strides.data(),
|
||||
dilations.data(),
|
||||
&mode)) {
|
||||
// Remove the extra dimension
|
||||
spatial_dims.resize(spatial_dim - 1);
|
||||
pads.resize(spatial_dim - 1);
|
||||
strides.resize(spatial_dim - 1);
|
||||
dilations.resize(spatial_dim - 1);
|
||||
(*this) << spatial_dim;
|
||||
(*this) << spatial_dims;
|
||||
(*this) << pads;
|
||||
(*this) << strides;
|
||||
(*this) << dilations;
|
||||
break;
|
||||
}
|
||||
spatial_dim += 1;
|
||||
ORT_ENFORCE(spatial_dim < 10,
|
||||
"miopenGetConvolutionNdDescriptor is supposed to fail before spatial_dim gets to ",
|
||||
spatial_dim);
|
||||
}
|
||||
}
|
||||
private:
|
||||
uint32_t value_ = BASIS;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class FusedConv : public onnxruntime::rocm::Conv<T> {
|
||||
public:
|
||||
using Base = onnxruntime::rocm::Conv<T>;
|
||||
FusedConv(const OpKernelInfo& info) : onnxruntime::rocm::Conv<T>(info) {
|
||||
std::string activation;
|
||||
if (info.GetAttr<std::string>("activation", &activation) == Status::OK() &&
|
||||
MapMode(activation) == Status::OK() &&
|
||||
miopenCreateActivationDescriptor(&activation_desc_) == miopenStatusSuccess) {
|
||||
status_ = miopenSetActivationDescriptor(activation_desc_,
|
||||
activation_mode_,
|
||||
0.0, 0.0, 0.0);
|
||||
}
|
||||
}
|
||||
|
||||
ORT_DISALLOW_COPY_AND_ASSIGNMENT(FusedConv);
|
||||
|
||||
~FusedConv() {
|
||||
if (activation_desc_) {
|
||||
miopenDestroyActivationDescriptor(activation_desc_);
|
||||
status_ = miopenStatusNotInitialized;
|
||||
activation_desc_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override {
|
||||
MIOPEN_RETURN_IF_ERROR(status_);
|
||||
std::lock_guard<OrtMutex> lock(Base::s_.mutex);
|
||||
|
||||
ORT_RETURN_IF_ERROR(Base::UpdateState(context, true));
|
||||
if (Base::s_.Y->Shape().Size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool has_z = nullptr != Base::s_.z_data;
|
||||
bool has_b = nullptr != Base::s_.b_data;
|
||||
auto factory = [this](FusedConvFusionData& fusion) {
|
||||
return this->DoCreateFusionDesc(fusion);
|
||||
};
|
||||
auto& cached_item = plan_cache_.FindOrCreateFusionPlanCache(Hash(),
|
||||
factory);
|
||||
bool should_try_fusion_api = cached_item.Validate(Base::MiopenHandle());
|
||||
|
||||
typedef typename onnxruntime::rocm::ToHipType<T>::MappedType HipT;
|
||||
const auto alpha = onnxruntime::rocm::Consts<HipT>::One;
|
||||
const auto beta = onnxruntime::rocm::Consts<HipT>::Zero;
|
||||
IAllocatorUniquePtr<void> workspace = Base::GetWorkSpace();
|
||||
miopenStatus_t fusion_status = miopenStatusNotInitialized;
|
||||
|
||||
if (should_try_fusion_api) {
|
||||
auto& fusion_info = *cached_item.fusion;
|
||||
MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsConvForward(fusion_info.fusion_args,
|
||||
fusion_info.conv_op,
|
||||
&alpha,
|
||||
&beta,
|
||||
Base::s_.w_data));
|
||||
if (has_z) {
|
||||
MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsBiasForward(fusion_info.fusion_args,
|
||||
fusion_info.bias_z_op,
|
||||
&alpha,
|
||||
&beta,
|
||||
Base::s_.z_data));
|
||||
}
|
||||
if (has_b) {
|
||||
MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsBiasForward(fusion_info.fusion_args,
|
||||
fusion_info.bias_b_op,
|
||||
&alpha,
|
||||
&beta,
|
||||
Base::s_.b_data));
|
||||
}
|
||||
if (activation_desc_) {
|
||||
const float relu_notused = 0.0;
|
||||
MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsActivForward(fusion_info.fusion_args,
|
||||
fusion_info.act_op,
|
||||
&alpha,
|
||||
&beta,
|
||||
relu_notused,
|
||||
relu_notused,
|
||||
relu_notused));
|
||||
}
|
||||
fusion_status = miopenExecuteFusionPlan(Base::MiopenHandle(),
|
||||
fusion_info.plan,
|
||||
Base::s_.x_tensor,
|
||||
Base::s_.x_data,
|
||||
Base::s_.y_tensor,
|
||||
Base::s_.y_data,
|
||||
fusion_info.fusion_args);
|
||||
}
|
||||
if (miopenStatusSuccess != fusion_status) {
|
||||
MIOPEN_RETURN_IF_ERROR(miopenConvolutionForward(Base::MiopenHandle(),
|
||||
&alpha,
|
||||
Base::s_.x_tensor,
|
||||
Base::s_.x_data,
|
||||
Base::s_.w_desc,
|
||||
Base::s_.w_data,
|
||||
Base::s_.conv_desc,
|
||||
Base::s_.fwd_algo,
|
||||
&beta,
|
||||
Base::s_.y_tensor,
|
||||
Base::s_.y_data,
|
||||
workspace.get(),
|
||||
Base::s_.workspace_bytes));
|
||||
if (has_b) {
|
||||
MIOPEN_RETURN_IF_ERROR(_miopenAddTensor(Base::MiopenHandle(),
|
||||
&alpha, Base::s_.b_tensor, Base::s_.b_data,
|
||||
&alpha, Base::s_.y_tensor, Base::s_.y_data,
|
||||
&beta));
|
||||
}
|
||||
if (has_z) {
|
||||
MIOPEN_RETURN_IF_ERROR(_miopenAddTensor(Base::MiopenHandle(),
|
||||
&alpha, Base::s_.z_tensor, Base::s_.z_data,
|
||||
&alpha, Base::s_.y_tensor, Base::s_.y_data,
|
||||
&beta));
|
||||
}
|
||||
MIOPEN_RETURN_IF_ERROR(miopenActivationForward(Base::MiopenHandle(),
|
||||
activation_desc_,
|
||||
&alpha,
|
||||
Base::s_.y_tensor,
|
||||
Base::s_.y_data,
|
||||
&beta,
|
||||
Base::s_.y_tensor,
|
||||
Base::s_.y_data));
|
||||
}
|
||||
if (Base::s_.post_slicing_required) {
|
||||
ORT_RETURN_IF_ERROR(onnxruntime::rocm::SliceOutUnwantedOutputSection(
|
||||
this->Stream(),
|
||||
Base::s_.y_data,
|
||||
Base::s_.y_dims_with_adjusted_pads,
|
||||
Base::s_.Y->MutableDataRaw(),
|
||||
Base::s_.y_dims.GetDims(),
|
||||
Base::s_.slice_starts,
|
||||
Base::s_.slice_ends,
|
||||
Base::s_.slice_axes,
|
||||
Base::s_.element_size));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
Status MapMode(const std::string& activaton_mode) {
|
||||
if (activaton_mode == "Relu") {
|
||||
activation_mode_ = miopenActivationMode_t::miopenActivationRELU;
|
||||
} else {
|
||||
return Status(common::StatusCategory::ONNXRUNTIME,
|
||||
common::StatusCode::INVALID_ARGUMENT,
|
||||
"unsupported conv activation mode");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
miopenStatus_t status_ = miopenStatusNotInitialized;
|
||||
miopenActivationMode_t activation_mode_;
|
||||
miopenActivationDescriptor_t activation_desc_ = nullptr;
|
||||
|
||||
// MIOpen Fusion API
|
||||
// TODO: create one fusion descriptor shared by multiple FusedConv
|
||||
// objects
|
||||
//
|
||||
// Considerations:
|
||||
// How to determine two FusedConv objects may share the same fusion
|
||||
// descriptor? Hashing x_tensor,conv_desc, etc.?
|
||||
struct FusedConvFusionData {
|
||||
miopenFusionPlanDescriptor_t plan = nullptr;
|
||||
miopenFusionOpDescriptor_t conv_op = nullptr;
|
||||
miopenFusionOpDescriptor_t bias_b_op = nullptr;
|
||||
miopenFusionOpDescriptor_t bias_z_op = nullptr;
|
||||
miopenFusionOpDescriptor_t act_op = nullptr;
|
||||
miopenOperatorArgs_t fusion_args = nullptr;
|
||||
|
||||
// TODO: There is a potential problem. miopenHandle_t may be destroyed and
|
||||
// re-created later, sharing the same address. Currently there is any way
|
||||
// to detect it?
|
||||
mutable std::unordered_set<miopenHandle_t> compiled_on;
|
||||
|
||||
FusedConvFusionData(const FusedConvFusionData&) = delete;
|
||||
FusedConvFusionData& operator= (const FusedConvFusionData&) = delete;
|
||||
|
||||
FusedConvFusionData(FusedConvFusionData&& other) {
|
||||
*this = std::move(other);
|
||||
}
|
||||
FusedConvFusionData& operator=(FusedConvFusionData&& other) {
|
||||
std::swap(this->plan, other.plan);
|
||||
std::swap(this->fusion_args, other.fusion_args);
|
||||
this->conv_op = other.conv_op;
|
||||
this->bias_b_op = other.bias_b_op;
|
||||
this->bias_z_op = other.bias_z_op;
|
||||
this->act_op = other.act_op;
|
||||
this->compiled_on = std::move(other.compiled_on);
|
||||
return *this;
|
||||
}
|
||||
|
||||
FusedConvFusionData() { }
|
||||
~FusedConvFusionData() {
|
||||
if (plan) {
|
||||
miopenDestroyFusionPlan(plan);
|
||||
}
|
||||
if (fusion_args) {
|
||||
miopenDestroyOperatorArgs(fusion_args);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct FusionPlanCacheItem {
|
||||
std::unique_ptr<FusedConvFusionData> fusion;
|
||||
Status creation_result;
|
||||
// TODO: Add a timestamp for eviction
|
||||
// std::chrono::time_point<std::chrono::high_resolution_clock> last_access;
|
||||
|
||||
FusionPlanCacheItem() {
|
||||
}
|
||||
|
||||
miopenStatus_t CompileOnHandle(miopenHandle_t handle) const {
|
||||
if (!fusion->plan) {
|
||||
return miopenStatusNotInitialized;
|
||||
}
|
||||
auto iter = fusion->compiled_on.find(handle);
|
||||
if (iter != fusion->compiled_on.end()) {
|
||||
return miopenStatusSuccess;
|
||||
}
|
||||
auto ret = miopenCompileFusionPlan(handle, fusion->plan);
|
||||
if (miopenStatusSuccess == ret) {
|
||||
fusion->compiled_on.insert(handle);
|
||||
}
|
||||
return miopenStatusSuccess;
|
||||
}
|
||||
|
||||
bool Validate(miopenHandle_t handle) const {
|
||||
if (Status::OK() != creation_result) {
|
||||
return false;
|
||||
}
|
||||
if (!fusion || !fusion->plan || !fusion->fusion_args) {
|
||||
return false;
|
||||
}
|
||||
auto compiling_status = CompileOnHandle(handle);
|
||||
if (miopenStatusSuccess != compiling_status) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
struct FusionPlanCache {
|
||||
mutable OrtMutex mutex;
|
||||
using HashKey = uint32_t;
|
||||
std::unordered_map<HashKey, FusionPlanCacheItem> cache_directory_;
|
||||
|
||||
FusionPlanCache() {
|
||||
}
|
||||
|
||||
FusionPlanCacheItem& FindOrCreateFusionPlanCache(HashKey key,
|
||||
std::function<Status(FusedConvFusionData& fusion)> factory) {
|
||||
std::lock_guard<OrtMutex> lock(mutex);
|
||||
auto iter = cache_directory_.find(key);
|
||||
if (iter == cache_directory_.end()) {
|
||||
cache_directory_[key].fusion = std::make_unique<FusedConvFusionData>();
|
||||
cache_directory_[key].creation_result = factory(*cache_directory_[key].fusion);
|
||||
if (Status::OK() != cache_directory_[key].creation_result) {
|
||||
cache_directory_[key].fusion.reset();
|
||||
}
|
||||
}
|
||||
return cache_directory_[key];
|
||||
}
|
||||
};
|
||||
|
||||
static FusionPlanCache plan_cache_;
|
||||
|
||||
Status DoCreateFusionDesc(FusedConvFusionData& fusion) const {
|
||||
bool has_z = nullptr != Base::s_.z_data;
|
||||
bool has_b = nullptr != Base::s_.b_data;
|
||||
MIOPEN_RETURN_IF_ERROR(miopenCreateFusionPlan(&fusion.plan,
|
||||
miopenVerticalFusion,
|
||||
Base::s_.x_tensor));
|
||||
MIOPEN_RETURN_IF_ERROR(miopenCreateOperatorArgs(&fusion.fusion_args));
|
||||
MIOPEN_RETURN_IF_ERROR(miopenCreateOpConvForward(fusion.plan,
|
||||
&fusion.conv_op,
|
||||
Base::s_.conv_desc,
|
||||
Base::s_.w_desc));
|
||||
if (has_z) {
|
||||
MIOPEN_RETURN_IF_ERROR(miopenCreateOpBiasForward(fusion.plan,
|
||||
&fusion.bias_z_op,
|
||||
Base::s_.z_tensor));
|
||||
}
|
||||
if (has_b) {
|
||||
MIOPEN_RETURN_IF_ERROR(miopenCreateOpBiasForward(fusion.plan,
|
||||
&fusion.bias_b_op,
|
||||
Base::s_.b_tensor));
|
||||
}
|
||||
if (activation_desc_) {
|
||||
MIOPEN_RETURN_IF_ERROR(miopenCreateOpActivationForward(fusion.plan,
|
||||
&fusion.act_op,
|
||||
activation_mode_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
uint32_t Hash() const {
|
||||
FNVHash hash;
|
||||
bool has_z = nullptr != Base::s_.z_data;
|
||||
bool has_b = nullptr != Base::s_.b_data;
|
||||
hash.HashTensor(Base::s_.x_tensor);
|
||||
hash.HashConvolutionDescriptor(Base::s_.conv_desc);
|
||||
hash.HashTensor(Base::s_.w_desc);
|
||||
if (has_z) {
|
||||
hash.HashTensor(Base::s_.z_tensor);
|
||||
}
|
||||
if (has_b) {
|
||||
hash.HashTensor(Base::s_.b_tensor);
|
||||
}
|
||||
if (activation_desc_) {
|
||||
hash << static_cast<int32_t>(activation_mode_);
|
||||
}
|
||||
return hash.GetValue();
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
typename FusedConv<T>::FusionPlanCache FusedConv<T>::plan_cache_;
|
||||
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX(
|
||||
FusedConv,
|
||||
kMSDomain,
|
||||
1,
|
||||
float,
|
||||
kRocmExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
FusedConv<float>);
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -205,7 +205,8 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, BFloat16_float_BFloat16, LayerNormalization)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedConv)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain,
|
||||
1, float, FusedConv)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GemmFastGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GemmFastGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GemmFastGelu)>,
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ static std::unordered_set<std::string> providers_except_cpu = {
|
|||
kArmNNExecutionProvider,
|
||||
kRocmExecutionProvider};
|
||||
|
||||
static std::unordered_set<std::string> providers_except_cpu_cuda = {
|
||||
static std::unordered_set<std::string> providers_except_cpu_gpu = {
|
||||
kDnnlExecutionProvider,
|
||||
kOpenVINOExecutionProvider,
|
||||
kNupharExecutionProvider,
|
||||
|
|
@ -49,8 +49,7 @@ static std::unordered_set<std::string> providers_except_cpu_cuda = {
|
|||
kDmlExecutionProvider,
|
||||
kMIGraphXExecutionProvider,
|
||||
kAclExecutionProvider,
|
||||
kArmNNExecutionProvider,
|
||||
kRocmExecutionProvider};
|
||||
kArmNNExecutionProvider};
|
||||
|
||||
|
||||
void TestConvOp(const ConvOpAndTestAttributes& attributes,
|
||||
|
|
@ -58,7 +57,7 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes,
|
|||
const vector<vector<int64_t>>& input_shapes,
|
||||
const std::initializer_list<float>& expected_output,
|
||||
const vector<int64_t>& expected_output_shape,
|
||||
const std::unordered_set<std::string>& excluded_provider_types = providers_except_cpu_cuda,
|
||||
const std::unordered_set<std::string>& excluded_provider_types = providers_except_cpu_gpu,
|
||||
bool weight_is_initializer = false,
|
||||
OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess,
|
||||
const std::string& err_str = "") {
|
||||
|
|
@ -162,9 +161,9 @@ TEST(FusedConvTest, Conv2D_Bias_Relu) {
|
|||
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape);
|
||||
}
|
||||
|
||||
#if defined(USE_CUDA)
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
|
||||
static std::unordered_set<std::string> providers_except_cuda = {
|
||||
static std::unordered_set<std::string> providers_except_gpu = {
|
||||
kCpuExecutionProvider,
|
||||
kDnnlExecutionProvider,
|
||||
kOpenVINOExecutionProvider,
|
||||
|
|
@ -177,8 +176,7 @@ static std::unordered_set<std::string> providers_except_cuda = {
|
|||
kDmlExecutionProvider,
|
||||
kMIGraphXExecutionProvider,
|
||||
kAclExecutionProvider,
|
||||
kArmNNExecutionProvider,
|
||||
kRocmExecutionProvider};
|
||||
kArmNNExecutionProvider};
|
||||
|
||||
TEST(FusedConvTest, Conv2D_Bias_Z_Relu) {
|
||||
ConvOpAndTestAttributes attrs = {
|
||||
|
|
@ -201,7 +199,7 @@ TEST(FusedConvTest, Conv2D_Bias_Z_Relu) {
|
|||
vector<float> Z = {-1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f};
|
||||
vector<int64_t> Z_shape = {1, 2, 2, 2};
|
||||
auto expected_vals = {12.0f, 17.0f, 25.0f, 29.0f, 11.0f, 15.0f, 23.0f, 28.0f};
|
||||
TestConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, providers_except_cuda);
|
||||
TestConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, providers_except_gpu);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in a new issue