mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
Add elementwise operators to DNNL execution provider (#8899)
The following ops have been added to the DNNL execution provider Abs, Elu, Exp, Log, *Relu, Round, Sigmoid, Softplus, Sqrt, and Tanh *Relu op was moved from its individual file to the elementwise operators The error tolerance for the LogGrad unit test had to be decreased slightly when using OneDNN. Still investigating why a differet tolerance value is needed. DnnlSubgraph::AddKernels() member function was moved to the top of the file since this is eddited every time a new operator is added to the the execution provider this places the code at the top which mean less scrooling when adding new kernels. Signed-off-by: George Nash <george.nash@intel.com>
This commit is contained in:
parent
2e37fe3f68
commit
dc75a135c8
9 changed files with 213 additions and 105 deletions
|
|
@ -372,6 +372,30 @@ bool DnnlBinaryNodeCapability::IsDimensionSupported(const Node* node) const {
|
|||
return true;
|
||||
}
|
||||
|
||||
// DnnlBinaryNodeCapability class
|
||||
//-------------------------------------
|
||||
bool DnnlElementwiseCapability::Supported(const Node* node) const {
|
||||
if (!IsTypeSupported(node)) return false;
|
||||
if (!IsDimensionSupported(node)) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DnnlElementwiseCapability::IsDimensionSupported(const Node* node) const {
|
||||
auto node_inputs = node->InputDefs();
|
||||
if (node_inputs[0]->Shape() == nullptr) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// OneDNN will silently convert scaler values to a {1} tensor which causes issues for
|
||||
// for Onnruntime when it expects an empty tensor i.e. {}
|
||||
// TODO convert {1} outputs back to scaler {} once that is done DnnlElementwiseCapability
|
||||
// can be removed and just us the DnnlDefaultNodeCapability.
|
||||
if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() == 0) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// DnnlGemmNodeCapability class
|
||||
//-------------------------------------
|
||||
bool DnnlGemmNodeCapability::Supported(const Node* node) const {
|
||||
|
|
|
|||
|
|
@ -207,6 +207,16 @@ class DnnlBinaryNodeCapability : public DnnlDefaultNodeCapability {
|
|||
bool IsDimensionSupported(const Node* node) const;
|
||||
};
|
||||
|
||||
class DnnlElementwiseCapability : public DnnlDefaultNodeCapability {
|
||||
public:
|
||||
DnnlElementwiseCapability() : DnnlDefaultNodeCapability({"float"}) {}
|
||||
|
||||
bool Supported(const Node* node) const override;
|
||||
|
||||
private:
|
||||
bool IsDimensionSupported(const Node* node) const;
|
||||
};
|
||||
|
||||
/**
|
||||
* Decide if a Gemm op is supported by DnnlExecutionProvider
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -6,24 +6,33 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
DnnlOpManager::DnnlOpManager() {
|
||||
dnnl_ops_map_.emplace(std::make_pair("Abs", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Add", std::unique_ptr<DnnlNodeCapability>(new DnnlBinaryNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("AveragePool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("BatchNormalization", std::unique_ptr<DnnlNodeCapability>(new DnnlBatchNormalizationNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Conv", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Div", std::unique_ptr<DnnlNodeCapability>(new DnnlBinaryNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Elu", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Exp", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Gemm", std::unique_ptr<DnnlNodeCapability>(new DnnlGemmNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("GlobalAveragePool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("GlobalMaxPool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Log", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("LRN", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("MatMul", std::unique_ptr<DnnlNodeCapability>(new DnnlMatMulNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("MatMulInteger", std::unique_ptr<DnnlNodeCapability>(new DnnlMatMulIntegerNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("MaxPool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Mul", std::unique_ptr<DnnlNodeCapability>(new DnnlBinaryNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("ReduceMean", std::unique_ptr<DnnlNodeCapability>(new DnnlReduceMeanNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Relu", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Relu", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Round", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Sigmoid", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Softmax", std::unique_ptr<DnnlNodeCapability>(new DnnlSoftmaxNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Softplus", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Sqrt", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Sub", std::unique_ptr<DnnlNodeCapability>(new DnnlBinaryNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Sum", std::unique_ptr<DnnlNodeCapability>(new DnnlSumNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Tanh", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
|
||||
#if defined(ENABLE_TRAINING)
|
||||
dnnl_ops_map_.emplace(std::make_pair("AveragePoolGrad", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("ConvGrad", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
|
||||
|
|
|
|||
80
onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.cc
Normal file
80
onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.cc
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
// Copyright(C) 2021 Intel Corporation
|
||||
// Licensed under the MIT License
|
||||
|
||||
#include "dnnl_elementwise.h"
|
||||
#include "dnnl_subgraph.h"
|
||||
#include "dnnl_subgraph_primitive.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace ort_dnnl {
|
||||
|
||||
DnnlElementwise::DnnlElementwise() {
|
||||
default_alpha_ = 1.0;
|
||||
}
|
||||
|
||||
void DnnlElementwise::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
|
||||
auto dnnl_engine = sp.GetEngine();
|
||||
|
||||
auto elementwise_src_mem = sp.GetMemory(node.Input(IN_X).Name());
|
||||
auto src_md = elementwise_src_mem.get_desc();
|
||||
dnnl::algorithm algo;
|
||||
bool requires_alpha = false;
|
||||
float alpha = 0.0;
|
||||
if (node.OpType() == "Abs") {
|
||||
algo = dnnl::algorithm::eltwise_abs;
|
||||
} else if (node.OpType() == "Elu") {
|
||||
requires_alpha = true;
|
||||
default_alpha_ = 1.0;
|
||||
alpha = GetAlpha(node);
|
||||
algo = dnnl::algorithm::eltwise_elu;
|
||||
} else if (node.OpType() == "Exp") {
|
||||
algo = dnnl::algorithm::eltwise_exp;
|
||||
} else if (node.OpType() == "Log") {
|
||||
algo = dnnl::algorithm::eltwise_log;
|
||||
} else if (node.OpType() == "Relu") {
|
||||
algo = dnnl::algorithm::eltwise_relu;
|
||||
} else if (node.OpType() == "Round") {
|
||||
algo = dnnl::algorithm::eltwise_round;
|
||||
} else if (node.OpType() == "Sigmoid") {
|
||||
// in OneDNN eltwise_logistic is defined as 1/(1 + exp(-x)) which matches the definition of "Sigmoid" in Onnx
|
||||
algo = dnnl::algorithm::eltwise_logistic;
|
||||
} else if (node.OpType() == "Softplus") {
|
||||
// in OneDNN eltwise_soft_relu is defined as ln(1 + exp(x)) which matches the definition of "Softplus" in Onnx
|
||||
algo = dnnl::algorithm::eltwise_soft_relu;
|
||||
} else if (node.OpType() == "Sqrt") {
|
||||
algo = dnnl::algorithm::eltwise_sqrt;
|
||||
} else if (node.OpType() == "Tanh") {
|
||||
algo = dnnl::algorithm::eltwise_tanh;
|
||||
} else {
|
||||
ORT_THROW("op type not supported");
|
||||
}
|
||||
dnnl::eltwise_forward::primitive_desc elementwise_pd;
|
||||
if (requires_alpha) {
|
||||
auto elementwise_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, src_md, alpha);
|
||||
elementwise_pd = dnnl::eltwise_forward::primitive_desc(elementwise_desc, dnnl_engine);
|
||||
} else {
|
||||
auto elementwise_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, src_md);
|
||||
elementwise_pd = dnnl::eltwise_forward::primitive_desc(elementwise_desc, dnnl_engine);
|
||||
}
|
||||
|
||||
// If using GPU this will move the memory from the CPU to the GPU.
|
||||
elementwise_src_mem = sp.GetMemoryAndReshape(node.Input(IN_X), elementwise_pd.src_desc(), dnnl_engine);
|
||||
auto elementwise_dst_mem = dnnl::memory(elementwise_pd.dst_desc(), dnnl_engine);
|
||||
|
||||
auto elemenwise_primitive = dnnl::eltwise_forward(elementwise_pd);
|
||||
sp.AddPrimitive(elemenwise_primitive, {{DNNL_ARG_SRC, elementwise_src_mem},
|
||||
{DNNL_ARG_DST, elementwise_dst_mem}});
|
||||
|
||||
sp.SetMemory(node.Output(OUT_Y), elementwise_dst_mem);
|
||||
}
|
||||
|
||||
float DnnlElementwise::GetAlpha(DnnlNode& node) {
|
||||
auto attr = node.Attributes().find("alpha");
|
||||
if (attr != node.Attributes().end()) {
|
||||
return attr->second().f();
|
||||
}
|
||||
return default_alpha_;
|
||||
}
|
||||
|
||||
} // namespace ort_dnnl
|
||||
} // namespace onnxruntime
|
||||
41
onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.h
Normal file
41
onnxruntime/core/providers/dnnl/subgraph/dnnl_elementwise.h
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright(C) 2021 Intel Corporation
|
||||
// Licensed under the MIT License
|
||||
|
||||
#pragma once
|
||||
#include "dnnl_subgraph.h"
|
||||
#include "dnnl_subgraph_primitive.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace ort_dnnl {
|
||||
|
||||
class DnnlElementwise {
|
||||
public:
|
||||
enum InputTensors : int {
|
||||
IN_X = 0,
|
||||
};
|
||||
|
||||
enum OutputTensors : int {
|
||||
OUT_Y = 0
|
||||
};
|
||||
|
||||
DnnlElementwise();
|
||||
void CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node);
|
||||
|
||||
private:
|
||||
/*
|
||||
* GetAlpha will get the 'alpha' attribute if the attribute is not found
|
||||
* the the `default_alph_` will be returned instead. This is set to 1.0
|
||||
* by the `DnnlElementwise` constructor but should be updated for any operator
|
||||
* that has an 'alpha' property.
|
||||
*
|
||||
* See how `GetAlpha` is called for the 'Elu' operator in the `CreatePrimitive` code.
|
||||
*
|
||||
* Note: The number of operators that use the 'alpha' attribute is much smaller than
|
||||
* initially expected.
|
||||
*/
|
||||
float GetAlpha(DnnlNode& node);
|
||||
float default_alpha_;
|
||||
};
|
||||
|
||||
} // namespace ort_dnnl
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,35 +0,0 @@
|
|||
// Copyright(C) 2021 Intel Corporation
|
||||
// Licensed under the MIT License
|
||||
|
||||
#include "dnnl_relu.h"
|
||||
#include "dnnl_subgraph.h"
|
||||
#include "dnnl_subgraph_primitive.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace ort_dnnl {
|
||||
|
||||
DnnlRelu::DnnlRelu() {}
|
||||
|
||||
void DnnlRelu::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
|
||||
|
||||
auto dnnl_engine = sp.GetEngine();
|
||||
|
||||
auto relu_src_mem = sp.GetMemory(node.Input(IN_X).Name());
|
||||
auto src_md = relu_src_mem.get_desc();
|
||||
|
||||
auto relu_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::eltwise_relu, src_md);
|
||||
auto relu_pd = dnnl::eltwise_forward::primitive_desc(relu_desc, dnnl_engine);
|
||||
|
||||
// If using GPU this will move the memory from the CPU to the GPU.
|
||||
relu_src_mem = sp.GetMemoryAndReshape(node.Input(IN_X), relu_pd.src_desc(), dnnl_engine);
|
||||
auto relu_dst_mem = dnnl::memory(relu_pd.dst_desc(), dnnl_engine);
|
||||
|
||||
auto relu_op = dnnl::eltwise_forward(relu_pd);
|
||||
sp.AddPrimitive(relu_op, {{DNNL_ARG_SRC, relu_src_mem},
|
||||
{DNNL_ARG_DST, relu_dst_mem}});
|
||||
|
||||
sp.SetMemory(node.Output(OUT_Y), relu_dst_mem);
|
||||
}
|
||||
|
||||
} // namespace ort_dnnl
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
// Copyright(C) 2021 Intel Corporation
|
||||
// Licensed under the MIT License
|
||||
|
||||
#pragma once
|
||||
#include "dnnl_subgraph.h"
|
||||
#include "dnnl_subgraph_primitive.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace ort_dnnl {
|
||||
|
||||
class DnnlRelu {
|
||||
public:
|
||||
enum InputTensors : int {
|
||||
IN_X = 0,
|
||||
};
|
||||
|
||||
enum OutputTensors : int {
|
||||
OUT_Y = 0
|
||||
};
|
||||
|
||||
DnnlRelu();
|
||||
void CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node);
|
||||
};
|
||||
|
||||
} // namespace ort_dnnl
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -6,13 +6,13 @@
|
|||
#include "dnnl_batchnorm.h"
|
||||
#include "dnnl_binary.h"
|
||||
#include "dnnl_conv.h"
|
||||
#include "dnnl_elementwise.h"
|
||||
#include "dnnl_gemm.h"
|
||||
#include "dnnl_lrn.h"
|
||||
#include "dnnl_matmul.h"
|
||||
#include "dnnl_matmul_integer.h"
|
||||
#include "dnnl_pool.h"
|
||||
#include "dnnl_reducemean.h"
|
||||
#include "dnnl_relu.h"
|
||||
#include "dnnl_softmax.h"
|
||||
#include "dnnl_sum.h"
|
||||
|
||||
|
|
@ -37,6 +37,49 @@ int Product(dnnl::memory::dims d) {
|
|||
return result;
|
||||
}
|
||||
|
||||
void DnnlSubgraphPrimitive::AddKernels() {
|
||||
std::unordered_set<std::string> binary_ops = {"Add", "Div", "Mul", "Sub"};
|
||||
std::unordered_set<std::string> elementwise_ops = {"Abs", "Elu", "Exp","Log", "Relu", "Round", "Sigmoid", "Softplus", "Sqrt", "Tanh"};
|
||||
std::unordered_set<std::string> pool_ops = {"AveragePool", "GlobalAveragePool", "GlobalMaxPool", "MaxPool"};
|
||||
for (auto& node : subgraph_->GetDnnlNodes()) {
|
||||
if (node.OpType() == "BatchNormalization") {
|
||||
DnnlBatchNorm().CreatePrimitive(*this, node);
|
||||
} else if (binary_ops.count(node.OpType())) {
|
||||
DnnlBinary().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "Conv") {
|
||||
DnnlConv().CreatePrimitive(*this, node);
|
||||
} else if (elementwise_ops.count(node.OpType())) {
|
||||
DnnlElementwise().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "Gemm") {
|
||||
DnnlGemm().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "LRN") {
|
||||
DnnlLrn().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "MatMul") {
|
||||
DnnlMatMul().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "MatMulInteger") {
|
||||
DnnlMatMulInteger().CreatePrimitive(*this, node);
|
||||
} else if (pool_ops.count(node.OpType())) {
|
||||
DnnlPool().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "ReduceMean") {
|
||||
DnnlReduceMean().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "Softmax") {
|
||||
DnnlSoftmax().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "Sum") {
|
||||
DnnlSum().CreatePrimitive(*this, node);
|
||||
#if defined(ENABLE_TRAINING)
|
||||
} else if (node.OpType() == "AveragePoolGrad" || node.OpType() == "MaxPoolGrad") {
|
||||
DnnlPoolGrad().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "ConvGrad") {
|
||||
DnnlConvGrad().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "ReluGrad") {
|
||||
DnnlReluGrad().CreatePrimitive(*this, node);
|
||||
#endif
|
||||
} else {
|
||||
throw std::invalid_argument("Kernel not found");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
DnnlSubgraphPrimitive::DnnlSubgraphPrimitive(ort_dnnl::DnnlSubgraph& dnnl_subgraph) {
|
||||
subgraph_ = &dnnl_subgraph;
|
||||
if (dnnl_engine_get_count(dnnl_engine_kind_t::dnnl_cpu)) {
|
||||
|
|
@ -194,48 +237,6 @@ void DnnlSubgraphPrimitive::AddInitializers() {
|
|||
}
|
||||
}
|
||||
|
||||
void DnnlSubgraphPrimitive::AddKernels() {
|
||||
std::unordered_set<std::string> binary_ops = {"Add", "Mul", "Sub", "Div"};
|
||||
for (auto& node : subgraph_->GetDnnlNodes()) {
|
||||
if (node.OpType() == "AveragePool" || node.OpType() == "GlobalAveragePool" ||
|
||||
node.OpType() == "GlobalMaxPool" || node.OpType() == "MaxPool") {
|
||||
DnnlPool().CreatePrimitive(*this, node);
|
||||
} else if (binary_ops.count(node.OpType())) {
|
||||
DnnlBinary().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "BatchNormalization") {
|
||||
DnnlBatchNorm().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "Conv") {
|
||||
DnnlConv().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "Gemm") {
|
||||
DnnlGemm().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "LRN") {
|
||||
DnnlLrn().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "MatMul") {
|
||||
DnnlMatMul().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "MatMulInteger") {
|
||||
DnnlMatMulInteger().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "ReduceMean") {
|
||||
DnnlReduceMean().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "Relu") {
|
||||
DnnlRelu().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "Softmax") {
|
||||
DnnlSoftmax().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "Sum") {
|
||||
DnnlSum().CreatePrimitive(*this, node);
|
||||
#if defined(ENABLE_TRAINING)
|
||||
} else if (node.OpType() == "AveragePoolGrad" || node.OpType() == "MaxPoolGrad") {
|
||||
DnnlPoolGrad().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "ConvGrad") {
|
||||
DnnlConvGrad().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "ReluGrad") {
|
||||
DnnlReluGrad().CreatePrimitive(*this, node);
|
||||
#endif
|
||||
} else {
|
||||
throw std::invalid_argument("Kernel not found");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DnnlSubgraphPrimitive::AddOutputs() {
|
||||
for (auto& tensor : subgraph_->GetDnnlOutputs()) {
|
||||
auto dnnl_data_type = tensor.Type();
|
||||
|
|
|
|||
|
|
@ -441,7 +441,11 @@ TEST(GradientCheckerTest, LogGrad) {
|
|||
TensorInfo x_info{shape, true, &transformer};
|
||||
|
||||
float max_error;
|
||||
#ifdef USE_DNNL
|
||||
float error_tolerance = 3e-3f;
|
||||
#else
|
||||
float error_tolerance = 1e-3f;
|
||||
#endif
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
OpDef op_def{"Log"};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue