mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Add Gemm op to DNNL Exectution provider (#8799)
* Implement Gemm op for DNNL execution provider Signed-off-by: George Nash <george.nash@intel.com> * Remove KernelRegistry and Gemm op for dnnl ep The KernelRegistry for the dnnl execution provider only registered a Gemm op that as best we can tell was never actually used and also was not using the dnnl library. We have implemented a Gemm op in the DNNL execution provider subgraph code and thus are removing the unused Gemm op that was in the dnnl KernelRegistry. Signed-off-by: George Nash <george.nash@intel.com> * Fix duplicated output and kernelshape inference fix getcapability to make sure subgraph outputs do not have duplicates fix kernelshape inference in pool Signed-off-by: Wang <zhaoyang.wang@intel.com> * Removed most dnnl specialized ifdefs from gradient_ops_test code Re-enable GlobalAveragePoolGrad test for dnnl ep The bugs that were exposed by the GlobalAveragePoolGrad test have been fixed and this test no longer needs to be disabled for DNNL. Removed the ReluGradDnnl test. We are getting the testing from the already existing ReluGrad test. MaxPoolGrad test no longer has specialized execution provider enabling for DNNL execution provider. It will now run without the extra enabling. ConvGrad is the only test that still has dnnl specialized ifdefs However, the ConvGrad code was not being executed by the code unless it was listed first in the list of execution providers. Signed-off-by: George Nash <george.nash@intel.com> * Fix transpose issue on Gemm On transposing square matrices, getmemoryandreshape will fail to reshape fix by adding a bool Signed-off-by: Wang <zhaoyang.wang@intel.com> * Save memory space by reusing internal tensor for output The intermediat matmul output tensor can be used as the output tensor for the binary calculation. Remove the unused IsAttributeSupported from the DnnlGemmNodeCapability class since we now support all of the Gemm attributes in our implementation. Signed-off-by: George Nash <george.nash@intel.com> Co-authored-by: Wang <zhaoyang.wang@intel.com>
This commit is contained in:
parent
89656bb712
commit
d4a88cfe3f
17 changed files with 331 additions and 321 deletions
|
|
@ -54,40 +54,6 @@ DNNLExecutionProvider::DNNLExecutionProvider(const DNNLExecutionProviderInfo& in
|
|||
DNNLExecutionProvider::~DNNLExecutionProvider() {
|
||||
}
|
||||
|
||||
namespace ort_dnnl {
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDnnlExecutionProvider, kOnnxDomain, 7, Gemm);
|
||||
|
||||
Status RegisterDNNLKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kDnnlExecutionProvider, kOnnxDomain, 7, Gemm)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace ort_dnnl
|
||||
|
||||
static std::shared_ptr<onnxruntime::KernelRegistry> s_kernel_registry;
|
||||
|
||||
void Shutdown_DeleteRegistry() {
|
||||
s_kernel_registry.reset();
|
||||
}
|
||||
|
||||
std::shared_ptr<KernelRegistry> DNNLExecutionProvider::GetKernelRegistry() const {
|
||||
if (!s_kernel_registry) {
|
||||
s_kernel_registry = KernelRegistry::Create();
|
||||
auto status = ort_dnnl::RegisterDNNLKernels(*s_kernel_registry);
|
||||
if (!status.IsOK())
|
||||
s_kernel_registry.reset();
|
||||
ORT_THROW_IF_ERROR(status);
|
||||
}
|
||||
|
||||
return s_kernel_registry;
|
||||
}
|
||||
|
||||
std::vector<std::vector<NodeIndex>> DNNLExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer) const {
|
||||
std::vector<std::vector<size_t>> supported_node_vecs;
|
||||
std::vector<size_t> supported_node_vec;
|
||||
|
|
@ -220,7 +186,7 @@ std::vector<std::unique_ptr<ComputeCapability>> DNNLExecutionProvider::GetCapabi
|
|||
for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) {
|
||||
if (node_set.count(it->GetNode().Index()) == 0) {
|
||||
const auto* output_def = output_defs[it->GetSrcArgIndex()];
|
||||
if (subgraph_outputs.count(output_def) == 0) {
|
||||
if (subgraph_outputs.count(output_def) == 0 && graph_outputs.count(output_def) == 0) {
|
||||
subgraph_outputs.insert(output_def);
|
||||
ordered_subgraph_outputs.push_back(output_def);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,8 +30,6 @@ class DNNLExecutionProvider : public IExecutionProvider {
|
|||
explicit DNNLExecutionProvider(const DNNLExecutionProviderInfo& info);
|
||||
virtual ~DNNLExecutionProvider();
|
||||
|
||||
virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;
|
||||
|
||||
std::vector<std::unique_ptr<ComputeCapability>>
|
||||
GetCapability(const onnxruntime::GraphViewer& graph,
|
||||
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override;
|
||||
|
|
|
|||
|
|
@ -339,6 +339,8 @@ bool DnnlSumNodeCapability::IsDimensionSupported(const Node* node) const {
|
|||
return true;
|
||||
}
|
||||
|
||||
// DnnlBinaryNodeCapability class
|
||||
//-------------------------------------
|
||||
bool DnnlBinaryNodeCapability::Supported(const Node* node) const {
|
||||
if (!IsTypeSupported(node)) return false;
|
||||
if (!IsDimensionSupported(node)) return false;
|
||||
|
|
@ -370,4 +372,12 @@ bool DnnlBinaryNodeCapability::IsDimensionSupported(const Node* node) const {
|
|||
return true;
|
||||
}
|
||||
|
||||
// DnnlGemmNodeCapability class
|
||||
//-------------------------------------
|
||||
bool DnnlGemmNodeCapability::Supported(const Node* node) const {
|
||||
if (!_matmul.Supported(node)) return false;
|
||||
if (!_binary.Supported(node)) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -207,4 +207,18 @@ class DnnlBinaryNodeCapability : public DnnlDefaultNodeCapability {
|
|||
bool IsDimensionSupported(const Node* node) const;
|
||||
};
|
||||
|
||||
/**
|
||||
* Decide if a Gemm op is supported by DnnlExecutionProvider
|
||||
*/
|
||||
class DnnlGemmNodeCapability : public DnnlDefaultNodeCapability {
|
||||
public:
|
||||
DnnlGemmNodeCapability() : DnnlDefaultNodeCapability({"float"}) {}
|
||||
|
||||
bool Supported(const Node* node) const override;
|
||||
|
||||
private:
|
||||
DnnlMatMulNodeCapability _matmul;
|
||||
DnnlBinaryNodeCapability _binary;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ DnnlOpManager::DnnlOpManager() {
|
|||
dnnl_ops_map_.emplace(std::make_pair("BatchNormalization", std::unique_ptr<DnnlNodeCapability>(new DnnlBatchNormalizationNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Conv", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Div", std::unique_ptr<DnnlNodeCapability>(new DnnlBinaryNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Gemm", std::unique_ptr<DnnlNodeCapability>(new DnnlGemmNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("GlobalAveragePool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("GlobalMaxPool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("LRN", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
|
||||
|
|
|
|||
|
|
@ -11,8 +11,6 @@ using namespace onnxruntime;
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
void Shutdown_DeleteRegistry();
|
||||
|
||||
struct DnnlProviderFactory : IExecutionProviderFactory {
|
||||
DnnlProviderFactory(bool create_arena) : create_arena_(create_arena) {}
|
||||
~DnnlProviderFactory() override {}
|
||||
|
|
@ -50,7 +48,7 @@ struct Dnnl_Provider : Provider {
|
|||
}
|
||||
|
||||
void Shutdown() override {
|
||||
Shutdown_DeleteRegistry();
|
||||
return;
|
||||
}
|
||||
|
||||
} g_provider;
|
||||
|
|
|
|||
|
|
@ -1,203 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/shared_library/provider_api.h"
|
||||
#include "gemm.h"
|
||||
#include "dnnl.h"
|
||||
#include "dnnl.hpp"
|
||||
#include "core/providers/dnnl/dnnl_fwd.h"
|
||||
#include "gsl/gsl"
|
||||
#include "onnxruntime_config.h"
|
||||
// build/external/eigen/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h:162:71:
|
||||
// error: ignoring attributes on template argument "Eigen::PacketType<const float, Eigen::DefaultDevice>::type {aka __vector(4) float}" [-Werror=ignored-attributes]
|
||||
#if defined(__GNUC__)
|
||||
#pragma GCC diagnostic push
|
||||
#if __GNUC__ >= 6
|
||||
#pragma GCC diagnostic ignored "-Wignored-attributes"
|
||||
#endif
|
||||
#pragma GCC diagnostic ignored "-Wunused-parameter"
|
||||
#ifdef HAS_DEPRECATED_COPY
|
||||
#pragma GCC diagnostic ignored "-Wdeprecated-copy"
|
||||
#endif
|
||||
#elif defined(_MSC_VER)
|
||||
// build\windows\debug\external\eigen3\unsupported\eigen\cxx11\src/Tensor/Tensor.h(76):
|
||||
// warning C4554: '&': check operator precedence for possible error; use parentheses to clarify precedence
|
||||
|
||||
// unsupported\eigen\cxx11\src\Tensor\TensorUInt128.h(150,0): Warning C4245: 'initializing': conversion from '__int64'
|
||||
// to 'uint64_t', signed/unsigned mismatch
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 4554)
|
||||
#pragma warning(disable : 4245)
|
||||
#pragma warning(disable : 4127)
|
||||
#endif
|
||||
|
||||
#include "Eigen/Core"
|
||||
|
||||
#if defined(__GNUC__)
|
||||
#pragma GCC diagnostic pop
|
||||
#elif defined(_MSC_VER)
|
||||
#pragma warning(pop)
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace ort_dnnl {
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Gemm,
|
||||
kOnnxDomain,
|
||||
7,
|
||||
kDnnlExecutionProvider,
|
||||
KernelDefBuilder::Create()->TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Gemm<float>);
|
||||
|
||||
class GemmHelper {
|
||||
public:
|
||||
GemmHelper(const TensorShape& left, bool trans_left, const TensorShape& right, bool trans_right, const TensorShape& bias) {
|
||||
ORT_ENFORCE(left.NumDimensions() == 2 || left.NumDimensions() == 1);
|
||||
ORT_ENFORCE(right.NumDimensions() == 2);
|
||||
|
||||
if (trans_left) {
|
||||
M_ = left.NumDimensions() == 2 ? left[1] : left[0];
|
||||
K_ = left.NumDimensions() == 2 ? left[0] : 1;
|
||||
} else {
|
||||
M_ = left.NumDimensions() == 2 ? left[0] : 1;
|
||||
K_ = left.NumDimensions() == 2 ? left[1] : left[0];
|
||||
}
|
||||
|
||||
int k_dim;
|
||||
if (trans_right) {
|
||||
N_ = right[0];
|
||||
k_dim = 1;
|
||||
} else {
|
||||
N_ = right[1];
|
||||
k_dim = 0;
|
||||
}
|
||||
|
||||
if (right[k_dim] != K_)
|
||||
status_ = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"GEMM: Dimension mismatch, W: ",
|
||||
right.ToString(),
|
||||
" K: " + std::to_string(K_),
|
||||
" N:" + std::to_string(N_));
|
||||
|
||||
if (!IsValidBroadcast(bias, M_, N_))
|
||||
status_ = common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Gemm: Invalid bias shape for broadcast");
|
||||
|
||||
// it is possible the input is empty tensor, for example the output of roipool in fast rcnn.
|
||||
ORT_ENFORCE(M_ >= 0 && K_ > 0 && N_ >= 0);
|
||||
}
|
||||
|
||||
int64_t M() const { return M_; }
|
||||
int64_t N() const { return N_; }
|
||||
int64_t K() const { return K_; }
|
||||
Status State() const { return status_; }
|
||||
|
||||
private:
|
||||
bool IsValidBroadcast(const TensorShape& bias_shape, int64_t M, int64_t N) {
|
||||
// valid shapes are (,) , (1, N) , (M, 1) , (M, N)
|
||||
if (bias_shape.NumDimensions() > 2)
|
||||
return false;
|
||||
// shape is (1,) or (1, 1), or (,)
|
||||
if (bias_shape.Size() == 1)
|
||||
return true;
|
||||
// valid bias_shape (s) are (N,) or (1, N) or (M, 1) or (M, N),
|
||||
// In last case no broadcasting needed, so don't fail it
|
||||
return ((bias_shape.NumDimensions() == 1 && bias_shape[0] == N) ||
|
||||
(bias_shape.NumDimensions() == 2 && bias_shape[0] == M && (bias_shape[1] == 1 || bias_shape[1] == N)) ||
|
||||
(bias_shape.NumDimensions() == 2 && bias_shape[0] == 1 && bias_shape[1] == N));
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t M_;
|
||||
int64_t K_;
|
||||
int64_t N_;
|
||||
Status status_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using EigenMatrixMapRowMajor = Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
|
||||
template <typename T>
|
||||
using ConstEigenVectorMap = Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>>;
|
||||
template <typename T>
|
||||
using ConstEigenMatrixMapRowMajor = Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
|
||||
|
||||
template <>
|
||||
Status Gemm<float>::Compute(OpKernelContext* ctx) const {
|
||||
const auto X = ctx->Input<Tensor>(0);
|
||||
const auto W = ctx->Input<Tensor>(1);
|
||||
const auto B = ctx->Input<Tensor>(2);
|
||||
GemmHelper helper(X->Shape(), trans_A_, W->Shape(), trans_B_, B->Shape());
|
||||
|
||||
if (!helper.State().IsOK())
|
||||
return helper.State();
|
||||
|
||||
dnnl::memory::dim M = gsl::narrow_cast<int>(helper.M());
|
||||
dnnl::memory::dim N = gsl::narrow_cast<int>(helper.N());
|
||||
dnnl::memory::dim K = gsl::narrow_cast<int>(helper.K());
|
||||
auto Y = ctx->Output(0, TensorShape({M, N}));
|
||||
|
||||
if (M <= 0)
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Empty Tensor not supported");
|
||||
|
||||
if (beta_ != 0) {
|
||||
auto output_mat = EigenMatrixMapRowMajor<float>(
|
||||
Y->template MutableData<float>(),
|
||||
M,
|
||||
N);
|
||||
output_mat.setZero();
|
||||
|
||||
auto& b_shape = B->Shape();
|
||||
// if B is (), (1,) or (1, 1), add the scalar
|
||||
if (b_shape.Size() == 1) {
|
||||
output_mat.array() += *(B->template Data<float>());
|
||||
}
|
||||
// B is (N,)
|
||||
else if (b_shape.NumDimensions() == 1) {
|
||||
auto bias_vec = ConstEigenVectorMap<float>(
|
||||
B->template Data<float>(),
|
||||
N);
|
||||
output_mat.rowwise() += bias_vec.transpose();
|
||||
} else if (b_shape.NumDimensions() == 2) {
|
||||
// B is (M, 1)
|
||||
if (b_shape[1] == 1) {
|
||||
auto bias_vec = ConstEigenVectorMap<float>(
|
||||
B->template Data<float>(),
|
||||
M);
|
||||
output_mat.colwise() += bias_vec;
|
||||
}
|
||||
// B is (1, N)
|
||||
else if (b_shape[0] == 1) {
|
||||
auto bias_vec = ConstEigenVectorMap<float>(
|
||||
B->template Data<float>(),
|
||||
N);
|
||||
output_mat.rowwise() += bias_vec.transpose();
|
||||
}
|
||||
// B is (M, N), no broadcast needed.
|
||||
else {
|
||||
auto bias_mat = ConstEigenMatrixMapRowMajor<float>(
|
||||
B->template Data<float>(),
|
||||
M,
|
||||
N);
|
||||
output_mat += bias_mat;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dnnl_sgemm expects row major matrices, so no need to swap the operands A and B
|
||||
auto status = dnnl_sgemm(trans_A_ ? 'T' : 'N',
|
||||
trans_B_ ? 'T' : 'N',
|
||||
M, N, K,
|
||||
alpha_, X->template Data<float>(), trans_A_ ? M : K,
|
||||
W->template Data<float>(), trans_B_ ? K : N,
|
||||
beta_, Y->template MutableData<float>(), N);
|
||||
if (status == dnnl_success) {
|
||||
return Status::OK();
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "DNNL_sgemm failed with status: ", status);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ort_dnnl
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace ort_dnnl {
|
||||
template <typename T>
|
||||
class Gemm final : public OpKernel {
|
||||
public:
|
||||
Gemm(const OpKernelInfo& info) : OpKernel(info) {
|
||||
int64_t temp;
|
||||
ORT_ENFORCE(info.GetAttr<int64_t>("transA", &temp).IsOK());
|
||||
trans_A_ = (temp != 0);
|
||||
|
||||
ORT_ENFORCE(info.GetAttr<int64_t>("transB", &temp).IsOK());
|
||||
trans_B_ = (temp != 0);
|
||||
|
||||
ORT_ENFORCE(info.GetAttr<float>("alpha", &alpha_).IsOK());
|
||||
ORT_ENFORCE(info.GetAttr<float>("beta", &beta_).IsOK());
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
bool trans_A_;
|
||||
bool trans_B_;
|
||||
float alpha_;
|
||||
float beta_;
|
||||
};
|
||||
} // namespace ort_dnnl
|
||||
} // namespace onnxruntime
|
||||
181
onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.cc
Normal file
181
onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.cc
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
// Copyright(C) 2021 Intel Corporation
|
||||
// Licensed under the MIT License
|
||||
|
||||
#include "dnnl_gemm.h"
|
||||
#include "dnnl_subgraph.h"
|
||||
#include "dnnl_subgraph_primitive.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace ort_dnnl {
|
||||
|
||||
DnnlGemm::DnnlGemm() {}
|
||||
|
||||
/*
|
||||
Gemm implementation:
|
||||
Gemm:
|
||||
Inputs:
|
||||
0) A - Input Tensor
|
||||
1) B - Input Tensor
|
||||
2) C - Input Tensor (optional if Opset is 11 or later)
|
||||
Outputs:
|
||||
0) Y - Output Tensor
|
||||
|
||||
+-----------+
|
||||
(A) | |
|
||||
---------->+ | AB +------+
|
||||
(B) | MatMul +--------------------->+ | alphaAB
|
||||
---------->+ | (alpha) | Mul +---+
|
||||
| | *--------------->+ | | +------+
|
||||
+-----------+ +------+ +---->+ | (Y) alphaAB + betaC
|
||||
| Add +---------------------->
|
||||
(C) +------+ +---->+ |
|
||||
--------------------------------------------->+ | | +------+
|
||||
(beta) | Mul +---+
|
||||
*--------------->+ | betaC
|
||||
+------+
|
||||
|
||||
Attributes (alpha, beta, transA, transB)
|
||||
|
||||
To compose Gemm: (algorithm)
|
||||
(1) perform `MatMul` on input tensors A and B result (AB)
|
||||
(2) if `Mul` the result of (1) by alpha attribute (alphaAB)
|
||||
(3) if C is optional return result from (2) and end
|
||||
(4) if C is avalible `Mul` input C tensor by beta attribute (betaC)
|
||||
(5) `Add` result from (2) to result from (4) (alphaAB + betaC)
|
||||
(6) Return output from (5) and end
|
||||
|
||||
OneDNN algorithm:
|
||||
(1) perform `MatMul` of tensor A and tensor B with `Output scales` set to alpha (0)
|
||||
(2) if C is optional return output from (1) and end
|
||||
(3) if C is avalible perform binary `Add` of output from (0) and input C with input C's `scale` attribute set to beta
|
||||
(4) return output from (4) and end
|
||||
|
||||
*/
|
||||
|
||||
|
||||
void DnnlGemm::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
|
||||
auto eng = sp.GetEngine();
|
||||
|
||||
auto a_dims = sp.GetMemory(node.Input(IN_A).Name()).get_desc().dims();
|
||||
auto b_dims = sp.GetMemory(node.Input(IN_B).Name()).get_desc().dims();
|
||||
|
||||
bool input_c_exists = node.Input(IN_C).Exists();
|
||||
|
||||
if (a_dims.size() != b_dims.size()) {
|
||||
while (a_dims.size() < b_dims.size()) {
|
||||
a_dims.insert(a_dims.begin(), 1);
|
||||
}
|
||||
while (a_dims.size() > b_dims.size()) {
|
||||
b_dims.insert(b_dims.begin(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
dnnl::memory::desc a_md;
|
||||
dnnl::memory::desc b_md;
|
||||
|
||||
bool transA = GetTransA(node);
|
||||
bool transB = GetTransB(node);
|
||||
|
||||
dnnl::memory::dim M = (transA) ? a_dims[1] : a_dims[0];
|
||||
dnnl::memory::dim K = (transA) ? a_dims[0] : a_dims[1];
|
||||
dnnl::memory::dim N = (transB) ? b_dims[0] : b_dims[1];
|
||||
|
||||
dnnl::memory::dims a_strides = (transA) ? dnnl::memory::dims{dnnl::memory::dim(1), M} : dnnl::memory::dims{K, dnnl::memory::dim(1)};
|
||||
dnnl::memory::dims b_strides = (transB) ? dnnl::memory::dims{dnnl::memory::dim(1), K} : dnnl::memory::dims{N, dnnl::memory::dim(1)};
|
||||
|
||||
a_md = dnnl::memory::desc({M, K}, node.Input(IN_A).Type(), a_strides);
|
||||
b_md = dnnl::memory::desc({K, N}, node.Input(IN_B).Type(), b_strides);
|
||||
|
||||
dnnl::memory::dims output_shape{M, N};
|
||||
|
||||
dnnl::primitive_attr matmul_attr;
|
||||
// scale the output from MatMul to alpha
|
||||
float alpha = GetAlpha(node);
|
||||
std::vector<float> alphaScale({alpha});
|
||||
matmul_attr.set_output_scales(0, alphaScale);
|
||||
|
||||
auto matmul_dst_md = dnnl::memory::desc(output_shape, node.Output(OUT_Y).Type(), {N, 1});
|
||||
|
||||
auto matmul_d = dnnl::matmul::desc(a_md, b_md, matmul_dst_md);
|
||||
dnnl::matmul::primitive_desc matmul_pd;
|
||||
matmul_pd = dnnl::matmul::primitive_desc(matmul_d, matmul_attr, eng);
|
||||
|
||||
auto matmul_a_mem = sp.GetMemoryAndReshape(node.Input(IN_A), matmul_pd.src_desc(), eng, transA);
|
||||
auto matmul_b_mem = sp.GetMemoryAndReshape(node.Input(IN_B), matmul_pd.weights_desc(), eng, transB);
|
||||
auto gemm_dst_mem = dnnl::memory(matmul_pd.dst_desc(), eng);
|
||||
|
||||
auto matmul_op = dnnl::matmul(matmul_pd);
|
||||
|
||||
std::unordered_map<int, dnnl::memory> args;
|
||||
args.insert({DNNL_ARG_SRC, matmul_a_mem});
|
||||
args.insert({DNNL_ARG_WEIGHTS, matmul_b_mem});
|
||||
args.insert({DNNL_ARG_DST, gemm_dst_mem});
|
||||
|
||||
sp.AddPrimitive(matmul_op, args);
|
||||
|
||||
if (input_c_exists) {
|
||||
auto c_original_md = sp.GetMemory(node.Input(IN_C).Name()).get_desc();
|
||||
auto c_dims = c_original_md.dims();
|
||||
if (c_dims.size() != a_dims.size()) {
|
||||
while (c_dims.size() < a_dims.size()) {
|
||||
c_dims.insert(c_dims.begin(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
auto c_md = c_original_md.reshape(c_dims);
|
||||
|
||||
auto y_md = dnnl::memory::desc(output_shape, node.Output(OUT_Y).Type(), dnnl::memory::format_tag::any);
|
||||
|
||||
auto binary_d = dnnl::binary::desc(dnnl::algorithm::binary_add, matmul_pd.dst_desc(), c_md, y_md);
|
||||
|
||||
// Scale input C by beta before adding it to the MatMul output.
|
||||
dnnl::primitive_attr binary_attr;
|
||||
float beta = GetBeta(node);
|
||||
binary_attr.set_scales(DNNL_ARG_SRC_1, 0, {beta});
|
||||
|
||||
auto binary_pd = dnnl::binary::primitive_desc(binary_d, binary_attr,eng);
|
||||
|
||||
auto binary_c_mem = sp.GetMemoryAndReshape(node.Input(IN_C), binary_pd.src1_desc(), eng);
|
||||
|
||||
auto binary_op = dnnl::binary(binary_pd);
|
||||
|
||||
sp.AddPrimitive(binary_op, {{DNNL_ARG_SRC_0, gemm_dst_mem},
|
||||
{DNNL_ARG_SRC_1, binary_c_mem},
|
||||
{DNNL_ARG_DST, gemm_dst_mem}});
|
||||
}
|
||||
sp.SetMemory(node.Output(OUT_Y), gemm_dst_mem);
|
||||
}
|
||||
|
||||
float DnnlGemm::GetAlpha(DnnlNode& node) {
|
||||
auto attr = node.Attributes().find("alpha");
|
||||
if (attr != node.Attributes().end()) {
|
||||
return attr->second().f();
|
||||
}
|
||||
return 1.0;
|
||||
}
|
||||
float DnnlGemm::GetBeta(DnnlNode& node) {
|
||||
auto attr = node.Attributes().find("beta");
|
||||
if (attr != node.Attributes().end()) {
|
||||
return attr->second().f();
|
||||
}
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
bool DnnlGemm::GetTransA(DnnlNode& node) {
|
||||
auto attr = node.Attributes().find("transA");
|
||||
if (attr != node.Attributes().end()) {
|
||||
return (attr->second().i() != 0);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool DnnlGemm::GetTransB(DnnlNode& node) {
|
||||
auto attr = node.Attributes().find("transB");
|
||||
if (attr != node.Attributes().end()) {
|
||||
return (attr->second().i() != 0);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace ort_dnnl
|
||||
} // namespace onnxruntime
|
||||
34
onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.h
Normal file
34
onnxruntime/core/providers/dnnl/subgraph/dnnl_gemm.h
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
// Copyright(C) 2021 Intel Corporation
|
||||
// Licensed under the MIT License
|
||||
|
||||
#pragma once
|
||||
#include "dnnl_subgraph.h"
|
||||
#include "dnnl_subgraph_primitive.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace ort_dnnl {
|
||||
|
||||
class DnnlGemm {
|
||||
public:
|
||||
enum InputTensors : int {
|
||||
IN_A = 0,
|
||||
IN_B = 1,
|
||||
IN_C = 2
|
||||
};
|
||||
|
||||
enum OutputTensors : int {
|
||||
OUT_Y = 0
|
||||
};
|
||||
|
||||
DnnlGemm();
|
||||
void CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node);
|
||||
|
||||
private:
|
||||
float GetAlpha(DnnlNode& node);
|
||||
float GetBeta(DnnlNode& node);
|
||||
bool GetTransA(DnnlNode& node);
|
||||
bool GetTransB(DnnlNode& node);
|
||||
};
|
||||
|
||||
} // namespace ort_dnnl
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -31,7 +31,7 @@ void DnnlPool::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
|
|||
}
|
||||
}
|
||||
|
||||
auto kernel_shape = GetKernelShape(node);
|
||||
auto kernel_shape = GetKernelShape(src_dims, node);
|
||||
PoolShape shape = static_cast<PoolShape>(kernel_shape.size());
|
||||
auto strides = GetStrides(node, shape);
|
||||
|
||||
|
|
@ -118,7 +118,7 @@ dnnl::memory::dims DnnlPool::GetDilations(DnnlNode& node, PoolShape shape) {
|
|||
return dnnl::memory::dims(dilations.begin(), dilations.end());
|
||||
}
|
||||
|
||||
dnnl::memory::dims DnnlPool::GetKernelShape(DnnlNode& node) {
|
||||
dnnl::memory::dims DnnlPool::GetKernelShape(const dnnl::memory::dims& src_dims, DnnlNode& node) {
|
||||
auto attr = node.Attributes().find("kernel_shape");
|
||||
std::vector<int64_t> kernel_shape;
|
||||
if (attr != node.Attributes().end()) {
|
||||
|
|
@ -128,9 +128,8 @@ dnnl::memory::dims DnnlPool::GetKernelShape(DnnlNode& node) {
|
|||
}
|
||||
return kernel_shape;
|
||||
}
|
||||
// Infer the Kernel shape from the input weights
|
||||
auto weight_dims = node.Input(IN_X).Dim();
|
||||
kernel_shape = std::vector<int64_t>(weight_dims.begin() + 2, weight_dims.end());
|
||||
|
||||
kernel_shape = std::vector<int64_t>(src_dims.begin() + 2, src_dims.end());
|
||||
return kernel_shape;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class DnnlPool {
|
|||
int64_t GetCeilMode(DnnlNode& node);
|
||||
int64_t GetCountIncludePadding(DnnlNode& node);
|
||||
dnnl::memory::dims GetDilations(DnnlNode& node, PoolShape shape);
|
||||
dnnl::memory::dims GetKernelShape(DnnlNode& node);
|
||||
dnnl::memory::dims GetKernelShape(const dnnl::memory::dims& src_dims, DnnlNode& node);
|
||||
/* This will return the calculated padding taking into account the DEPRECATED auto_pad attribute */
|
||||
std::vector<int64_t> InferPadding(DnnlNode& node, const dnnl::memory::dims& src_dims, const dnnl::memory::dims& kernel_shape, const dnnl::memory::dims& strides);
|
||||
std::vector<int64_t> GetPadding(DnnlNode& node, PoolShape shape);
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include "dnnl_batchnorm.h"
|
||||
#include "dnnl_binary.h"
|
||||
#include "dnnl_conv.h"
|
||||
#include "dnnl_gemm.h"
|
||||
#include "dnnl_lrn.h"
|
||||
#include "dnnl_matmul.h"
|
||||
#include "dnnl_matmul_integer.h"
|
||||
|
|
@ -205,6 +206,8 @@ void DnnlSubgraphPrimitive::AddKernels() {
|
|||
DnnlBatchNorm().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "Conv") {
|
||||
DnnlConv().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "Gemm") {
|
||||
DnnlGemm().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "LRN") {
|
||||
DnnlLrn().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "MatMul") {
|
||||
|
|
@ -356,7 +359,7 @@ void DnnlSubgraphPrimitive::SetInitializer(std::string memory_name, dnnl::memory
|
|||
}
|
||||
}
|
||||
|
||||
dnnl::memory DnnlSubgraphPrimitive::GetMemoryAndReshape(ort_dnnl::DnnlTensor tensor, dnnl::memory::desc mem_desc, dnnl::engine eng) {
|
||||
dnnl::memory DnnlSubgraphPrimitive::GetMemoryAndReshape(ort_dnnl::DnnlTensor tensor, dnnl::memory::desc mem_desc, dnnl::engine eng, bool transpose) {
|
||||
// if found just return
|
||||
if (HasMemory(tensor.Name(), mem_desc, eng)) {
|
||||
return GetMemory(tensor.Name(), mem_desc, eng);
|
||||
|
|
@ -372,7 +375,7 @@ dnnl::memory DnnlSubgraphPrimitive::GetMemoryAndReshape(ort_dnnl::DnnlTensor ten
|
|||
auto mem_to = dnnl::memory(mem_desc, eng);
|
||||
|
||||
// if it is a reshape, ensure reorder is possible by making the same dims
|
||||
if (mem_from.get_desc().dims() != mem_to.get_desc().dims()) {
|
||||
if (mem_from.get_desc().dims() != mem_to.get_desc().dims() || transpose) {
|
||||
auto mem_from_dims = mem_from.get_desc().dims();
|
||||
auto mem_to_dims = mem_to.get_desc().dims();
|
||||
if (Product(mem_from_dims) != Product(mem_to_dims)) {
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ class DnnlSubgraphPrimitive {
|
|||
dnnl::stream GetStream();
|
||||
|
||||
//obtain a dnnl::memory with specified name, memory descriptor and engine, will perform extra reorder/reshape if necessary before returning
|
||||
dnnl::memory GetMemoryAndReshape(ort_dnnl::DnnlTensor tensor, dnnl::memory::desc mem_desc, dnnl::engine eng);
|
||||
dnnl::memory GetMemoryAndReshape(ort_dnnl::DnnlTensor tensor, dnnl::memory::desc mem_desc, dnnl::engine eng, bool transpose = false);
|
||||
//add dnnl primitive and memory map to subgraph primitive
|
||||
void AddPrimitive(dnnl::primitive prim, std::unordered_map<int, dnnl::memory> mem_map);
|
||||
//add a reshape (e.g. squeeze, unsqueeze) to subgraph primitive
|
||||
|
|
|
|||
|
|
@ -225,6 +225,67 @@ TEST(GemmOpTest, GemmTransB_1) {
|
|||
TestGemmTransB_1<double>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TestGemmAlpha() {
|
||||
OpTester test("Gemm");
|
||||
|
||||
test.AddAttribute("transA", (int64_t)0);
|
||||
test.AddAttribute("transB", (int64_t)0);
|
||||
test.AddAttribute("alpha", 0.5f);
|
||||
test.AddAttribute("beta", 1.0f);
|
||||
|
||||
test.AddInput<T>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f));
|
||||
test.AddInput<T>("C", {3}, std::vector<T>(3, 1.0f));
|
||||
test.AddOutput<T>("Y", {2, 3},
|
||||
{6.0f, 6.0f, 6.0f,
|
||||
-4.0f, -4.0f, -4.0f});
|
||||
//test.AddOutput<T>("Y", {2, 3},
|
||||
// {5.0f, 5.0f, 5.0f,
|
||||
// -5.0f, -5.0f, -5.0f});
|
||||
#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32)
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues
|
||||
#else
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: Seg fault in parser
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, GemmAlpha) {
|
||||
TestGemmAlpha<float>();
|
||||
TestGemmAlpha<double>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TestGemmBeta() {
|
||||
OpTester test("Gemm");
|
||||
|
||||
test.AddAttribute("transA", (int64_t)0);
|
||||
test.AddAttribute("transB", (int64_t)0);
|
||||
test.AddAttribute("alpha", 1.0f);
|
||||
test.AddAttribute("beta", 2.0f);
|
||||
|
||||
test.AddInput<T>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<T>("B", {4, 3}, std::vector<T>(12, 1.0f));
|
||||
test.AddInput<T>("C", {3}, std::vector<T>(3, 1.0f));
|
||||
test.AddOutput<T>("Y", {2, 3},
|
||||
{12.0f, 12.0f, 12.0f,
|
||||
-8.0f, -8.0f, -8.0f});
|
||||
#if defined(OPENVINO_CONFIG_GPU_FP16) || defined(OPENVINO_CONFIG_GPU_FP32)
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues
|
||||
#else
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: Seg fault in parser
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, GemmBeta) {
|
||||
TestGemmBeta<float>();
|
||||
TestGemmBeta<double>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TestGemmAlphaBeta() {
|
||||
OpTester test("Gemm");
|
||||
|
|
|
|||
|
|
@ -192,6 +192,12 @@ void GradientOpTester::Run(
|
|||
|
||||
//if node is not registered for the provider, skip
|
||||
node.SetExecutionProviderType(provider_type);
|
||||
|
||||
// provider types that don't use the KernelRegistry
|
||||
if (provider_type == onnxruntime::kDnnlExecutionProvider) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto reg = execution_provider->GetKernelRegistry();
|
||||
const KernelCreateInfo* kci;
|
||||
auto st = reg->TryFindKernel(node, execution_provider->Type(), &kci);
|
||||
|
|
|
|||
|
|
@ -663,14 +663,6 @@ TEST(GradientCheckerTest, ReluGrad) {
|
|||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
#ifdef USE_DNNL
|
||||
TEST(GradientCheckerTest, ReluGradDnnl) {
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultDnnlExecutionProvider());
|
||||
UnaryOpGradientTest("Relu", kOnnxDomain, 9, &execution_providers);
|
||||
}
|
||||
#endif // USE_DNNL
|
||||
|
||||
TEST(GradientCheckerTest, CastGrad) {
|
||||
// A dummy test that cast float to float
|
||||
// TODO: add more test here
|
||||
|
|
@ -728,7 +720,7 @@ static std::vector<std::vector<T>> GetRandomValuesForMaxPool(const std::vector<T
|
|||
return datas;
|
||||
}
|
||||
|
||||
void MaxpoolGradientCheckerTest(std::vector<std::unique_ptr<IExecutionProvider>>* execution_provider) {
|
||||
TEST(GradientCheckerTest, MaxPoolGrad) {
|
||||
float max_error;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
OpDef op_def{"MaxPool"};
|
||||
|
|
@ -737,9 +729,7 @@ void MaxpoolGradientCheckerTest(std::vector<std::unique_ptr<IExecutionProvider>>
|
|||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 2, 9}}, {{2, 2, 8}}, &max_error,
|
||||
GetRandomValuesForMaxPool<float>({{2, 2, 9}}),
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{2})},
|
||||
true, false,
|
||||
execution_provider);
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{2})});
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
|
|
@ -748,9 +738,7 @@ void MaxpoolGradientCheckerTest(std::vector<std::unique_ptr<IExecutionProvider>>
|
|||
gradient_checker.ComputeGradientError(op_def, {{2, 3, 5, 5}}, {{2, 3, 4, 4}}, &max_error,
|
||||
GetRandomValuesForMaxPool<float>({{2, 3, 5, 5}}),
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{2, 2}),
|
||||
MakeAttribute("strides", std::vector<int64_t>{1, 1})},
|
||||
true, false,
|
||||
execution_provider);
|
||||
MakeAttribute("strides", std::vector<int64_t>{1, 1})});
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
|
|
@ -759,9 +747,7 @@ void MaxpoolGradientCheckerTest(std::vector<std::unique_ptr<IExecutionProvider>>
|
|||
gradient_checker.ComputeGradientError(op_def, {{1, 1, 5, 5}}, {{1, 1, 7, 7}}, &max_error,
|
||||
GetRandomValuesForMaxPool<float>({{1, 1, 5, 5}}),
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{3, 3}),
|
||||
MakeAttribute("pads", std::vector<int64_t>{2, 2, 2, 2})},
|
||||
true, false,
|
||||
execution_provider);
|
||||
MakeAttribute("pads", std::vector<int64_t>{2, 2, 2, 2})});
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
|
|
@ -770,9 +756,7 @@ void MaxpoolGradientCheckerTest(std::vector<std::unique_ptr<IExecutionProvider>>
|
|||
gradient_checker.ComputeGradientError(op_def, {{1, 1, 32, 32}}, {{1, 1, 10, 10}}, &max_error,
|
||||
GetRandomValuesForMaxPool<float>({{1, 1, 32, 32}}),
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{5, 5}),
|
||||
MakeAttribute("strides", std::vector<int64_t>{3, 3})},
|
||||
true, false,
|
||||
execution_provider);
|
||||
MakeAttribute("strides", std::vector<int64_t>{3, 3})});
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
|
|
@ -780,23 +764,11 @@ void MaxpoolGradientCheckerTest(std::vector<std::unique_ptr<IExecutionProvider>>
|
|||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 1, 3, 3, 3}}, {{2, 1, 2, 2, 2}}, &max_error,
|
||||
GetRandomValuesForMaxPool<float>({{2, 1, 3, 3, 3}}),
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{2, 2, 2})},
|
||||
true, false,
|
||||
execution_provider);
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{2, 2, 2})});
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, MaxPoolGrad) {
|
||||
MaxpoolGradientCheckerTest(nullptr);
|
||||
|
||||
#ifdef USE_DNNL
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultDnnlExecutionProvider());
|
||||
MaxpoolGradientCheckerTest(&execution_providers);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, GlobalAveragePoolGrad) {
|
||||
float max_error;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
|
|
@ -1057,15 +1029,17 @@ void ConvGradientCheckerTest(std::vector<std::unique_ptr<IExecutionProvider>>* e
|
|||
|
||||
TEST(GradientCheckerTest, ConvGrad) {
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
#ifdef USE_DNNL
|
||||
// Dnnl EP does not run for ConvGrad unless it is pushed first.
|
||||
execution_providers.push_back(DefaultDnnlExecutionProvider());
|
||||
#endif
|
||||
|
||||
execution_providers.push_back(DefaultCpuExecutionProvider());
|
||||
|
||||
if (HasCudaEnvironment(700)) {
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
}
|
||||
|
||||
#ifdef USE_DNNL
|
||||
execution_providers.push_back(DefaultDnnlExecutionProvider());
|
||||
#endif
|
||||
|
||||
ConvGradientCheckerTest(&execution_providers);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue