mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
mkldnn conv 1d, 3d support. (#130)
This commit is contained in:
parent
a09a3d3aa5
commit
ed5abc8b94
2 changed files with 102 additions and 77 deletions
|
|
@ -13,21 +13,21 @@ namespace onnxruntime {
|
|||
namespace mkl_dnn {
|
||||
|
||||
namespace {
|
||||
// Struct which encapsulates parameters for MKLDNN Conv2d primitive.
|
||||
struct Conv2dParams {
|
||||
mkldnn::memory::dims& src_dims;
|
||||
mkldnn::memory::dims& filter_dims;
|
||||
mkldnn::memory::dims& bias_dims;
|
||||
mkldnn::memory::dims& dst_dims;
|
||||
mkldnn::memory::dims& strides;
|
||||
mkldnn::memory::dims& dilations;
|
||||
mkldnn::memory::dims& padding_left;
|
||||
mkldnn::memory::dims& padding_right;
|
||||
// Struct which encapsulates parameters for MKLDNN Conv primitive.
|
||||
struct ConvParams {
|
||||
const mkldnn::memory::dims& src_dims;
|
||||
const mkldnn::memory::dims& filter_dims;
|
||||
const mkldnn::memory::dims& bias_dims;
|
||||
const mkldnn::memory::dims& dst_dims;
|
||||
const mkldnn::memory::dims& strides;
|
||||
const mkldnn::memory::dims& dilations;
|
||||
const mkldnn::memory::dims& padding_left;
|
||||
const mkldnn::memory::dims& padding_right;
|
||||
|
||||
Conv2dParams(mkldnn::memory::dims& src_dims, mkldnn::memory::dims& filter_dims,
|
||||
mkldnn::memory::dims& bias_dims, mkldnn::memory::dims& dst_dims,
|
||||
mkldnn::memory::dims& strides, mkldnn::memory::dims& dilations,
|
||||
mkldnn::memory::dims& padding_left, mkldnn::memory::dims& padding_right)
|
||||
ConvParams(const mkldnn::memory::dims& src_dims, const mkldnn::memory::dims& filter_dims,
|
||||
const mkldnn::memory::dims& bias_dims, const mkldnn::memory::dims& dst_dims,
|
||||
const mkldnn::memory::dims& strides, const mkldnn::memory::dims& dilations,
|
||||
const mkldnn::memory::dims& padding_left, const mkldnn::memory::dims& padding_right)
|
||||
: src_dims(src_dims),
|
||||
filter_dims(filter_dims),
|
||||
bias_dims(bias_dims),
|
||||
|
|
@ -37,11 +37,11 @@ struct Conv2dParams {
|
|||
padding_left(padding_left),
|
||||
padding_right(padding_right) {}
|
||||
|
||||
// Used as the key for Conv2d Primitive Reuse Pool.
|
||||
// Used as the key for Conv Primitive Reuse Pool.
|
||||
std::string ToString() const {
|
||||
std::string key;
|
||||
key.reserve(128);
|
||||
key.append("conv2d_");
|
||||
key.append("conv_");
|
||||
AddDimsToKey(key, src_dims);
|
||||
AddDimsToKey(key, filter_dims);
|
||||
AddDimsToKey(key, bias_dims);
|
||||
|
|
@ -55,9 +55,9 @@ struct Conv2dParams {
|
|||
};
|
||||
|
||||
template <typename T>
|
||||
class Conv2dPrimitive : public PrimitiveBase {
|
||||
class ConvPrimitive : public PrimitiveBase {
|
||||
public:
|
||||
explicit Conv2dPrimitive(const Conv2dParams& params)
|
||||
explicit ConvPrimitive(const ConvParams& params)
|
||||
: cpu_engine_(GetEngine()) {
|
||||
context_.stream.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
|
||||
if (context_.conv_fwd == nullptr) {
|
||||
|
|
@ -65,7 +65,7 @@ class Conv2dPrimitive : public PrimitiveBase {
|
|||
}
|
||||
}
|
||||
|
||||
~Conv2dPrimitive() = default;
|
||||
~ConvPrimitive() = default;
|
||||
|
||||
void Compute(const T* src_data, const T* filter_data,
|
||||
const T* dst_data, const T* bias_data = nullptr) {
|
||||
|
|
@ -107,7 +107,7 @@ class Conv2dPrimitive : public PrimitiveBase {
|
|||
}
|
||||
|
||||
private:
|
||||
struct Conv2dContext {
|
||||
struct ConvContext {
|
||||
mkldnn::memory::format src_fmt;
|
||||
mkldnn::memory::format filter_fmt;
|
||||
mkldnn::memory::format dst_fmt;
|
||||
|
|
@ -134,7 +134,7 @@ class Conv2dPrimitive : public PrimitiveBase {
|
|||
std::unique_ptr<mkldnn::stream> stream;
|
||||
std::vector<mkldnn::primitive> net;
|
||||
|
||||
Conv2dContext()
|
||||
ConvContext()
|
||||
: src_fmt(mkldnn::memory::format::any),
|
||||
filter_fmt(mkldnn::memory::format::any),
|
||||
dst_fmt(mkldnn::memory::format::any),
|
||||
|
|
@ -154,7 +154,7 @@ class Conv2dPrimitive : public PrimitiveBase {
|
|||
stream(nullptr) {}
|
||||
};
|
||||
|
||||
void Initialize(const Conv2dParams& params) {
|
||||
void Initialize(const ConvParams& params) {
|
||||
// Set the memory descriptors to format::any to allow MKLDNN to decide what the optimal memory layout should be
|
||||
// for the computation given the input params.
|
||||
context_.src_md.reset(new mkldnn::memory::desc(
|
||||
|
|
@ -222,33 +222,33 @@ class Conv2dPrimitive : public PrimitiveBase {
|
|||
context_.net.push_back(*context_.conv_fwd);
|
||||
}
|
||||
|
||||
Conv2dContext context_;
|
||||
ConvContext context_;
|
||||
mkldnn::engine& cpu_engine_;
|
||||
};
|
||||
|
||||
// Pool which allows for reuse of MKLDNN Conv2d primitives which are expensive to instantiate.
|
||||
// Pool which allows for reuse of MKLDNN Conv primitives which are expensive to instantiate.
|
||||
// To address thread safety, the primitives are stored in a map on thread local storage.
|
||||
template <typename T>
|
||||
class Conv2dPrimitivePool : public PrimitivePool<T> {
|
||||
class ConvPrimitivePool : public PrimitivePool<T> {
|
||||
public:
|
||||
static Conv2dPrimitive<T>* Get(const Conv2dParams& params) {
|
||||
Conv2dPrimitive<T>* primitive = dynamic_cast<Conv2dPrimitive<T>*>(
|
||||
Conv2dPrimitivePool<T>::GetInstance().GetPrimitive(params.ToString()));
|
||||
static ConvPrimitive<T>* Get(const ConvParams& params) {
|
||||
ConvPrimitive<T>* primitive = dynamic_cast<ConvPrimitive<T>*>(
|
||||
ConvPrimitivePool<T>::GetInstance().GetPrimitive(params.ToString()));
|
||||
|
||||
if (primitive == nullptr) {
|
||||
auto conv2d_primitive = std::make_unique<Conv2dPrimitive<T>>(params);
|
||||
primitive = conv2d_primitive.get();
|
||||
Conv2dPrimitivePool<T>::GetInstance().SetPrimitive(params.ToString(), std::move(conv2d_primitive));
|
||||
auto conv_primitive = std::make_unique<ConvPrimitive<T>>(params);
|
||||
primitive = conv_primitive.get();
|
||||
ConvPrimitivePool<T>::GetInstance().SetPrimitive(params.ToString(), std::move(conv_primitive));
|
||||
}
|
||||
return primitive;
|
||||
}
|
||||
|
||||
private:
|
||||
Conv2dPrimitivePool() = default;
|
||||
~Conv2dPrimitivePool() = default;
|
||||
ConvPrimitivePool() = default;
|
||||
~ConvPrimitivePool() = default;
|
||||
|
||||
static Conv2dPrimitivePool& GetInstance() {
|
||||
static Conv2dPrimitivePool pool;
|
||||
static ConvPrimitivePool& GetInstance() {
|
||||
static ConvPrimitivePool pool;
|
||||
return pool;
|
||||
}
|
||||
};
|
||||
|
|
@ -268,20 +268,20 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
|
|||
ONNXRUNTIME_RETURN_IF_ERROR(onnxruntime::ConvBase::ValidateInputShape(X, W));
|
||||
|
||||
std::vector<int64_t> kernel_shape = onnxruntime::ConvBase::ComputeKernelShape(W->Shape());
|
||||
const size_t kernel_rank = kernel_shape.size();
|
||||
|
||||
// TODO: Support more than 2d kernels
|
||||
if (kernel_shape.size() != 2) {
|
||||
if (kernel_rank > 3) {
|
||||
// Fall Back to CPU implementation.
|
||||
return onnxruntime::Conv<T>::Compute(context);
|
||||
}
|
||||
|
||||
if (kernel_shape.size() + 2 != W->Shape().NumDimensions()) {
|
||||
if (kernel_rank + 2 != W->Shape().NumDimensions()) {
|
||||
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape num_dims is not compatible with W num_dims.",
|
||||
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
|
||||
" W: ", W->Shape().ToString().c_str());
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < kernel_shape.size(); ++i) {
|
||||
for (size_t i = 0; i < kernel_rank; ++i) {
|
||||
if (kernel_shape[i] != W->Shape()[i + 2]) {
|
||||
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape is not compatible with W shape.",
|
||||
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
|
||||
|
|
@ -291,15 +291,15 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
|
|||
|
||||
std::vector<int64_t> pads(onnxruntime::ConvBase::pads_);
|
||||
if (pads.empty()) {
|
||||
pads.resize(kernel_shape.size() * 2, 0);
|
||||
pads.resize(kernel_rank * 2, 0);
|
||||
}
|
||||
std::vector<int64_t> dilations(onnxruntime::ConvBase::dilations_);
|
||||
if (dilations.empty()) {
|
||||
dilations.resize(kernel_shape.size(), 1);
|
||||
dilations.resize(kernel_rank, 1);
|
||||
}
|
||||
std::vector<int64_t> strides(onnxruntime::ConvBase::strides_);
|
||||
if (strides.empty()) {
|
||||
strides.resize(kernel_shape.size(), 1);
|
||||
strides.resize(kernel_rank, 1);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Y_dims;
|
||||
|
|
@ -314,21 +314,19 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
|
|||
if (group_mkl == 1) {
|
||||
filter_dims_mkl.assign(W->Shape().GetDims().begin(), W->Shape().GetDims().end());
|
||||
} else {
|
||||
filter_dims_mkl.assign({
|
||||
group_mkl,
|
||||
static_cast<int>(W->Shape()[0] / group_mkl),
|
||||
static_cast<int>(W->Shape()[1]),
|
||||
static_cast<int>(W->Shape()[2]),
|
||||
static_cast<int>(W->Shape()[3]),
|
||||
});
|
||||
filter_dims_mkl.assign({group_mkl,
|
||||
static_cast<int>(W->Shape()[0] / group_mkl)});
|
||||
filter_dims_mkl.insert(filter_dims_mkl.end(), W->Shape().GetDims().begin() + 1, W->Shape().GetDims().end());
|
||||
}
|
||||
mkldnn::memory::dims strides_mkl(strides.begin(), strides.end());
|
||||
mkldnn::memory::dims dilations_mkl(dilations.begin(), dilations.end());
|
||||
// mkldnn dilations start from 0 so we need to subtract 1 from each dim.
|
||||
dilations_mkl[0] -= 1;
|
||||
dilations_mkl[1] -= 1;
|
||||
mkldnn::memory::dims padding_left_mkl(pads.begin(), pads.begin() + 2);
|
||||
mkldnn::memory::dims padding_right_mkl(pads.begin() + 2, pads.end());
|
||||
for (size_t dim = 0; dim < kernel_rank; dim++) {
|
||||
dilations_mkl[dim] -= 1;
|
||||
}
|
||||
|
||||
mkldnn::memory::dims padding_left_mkl(pads.begin(), pads.begin() + kernel_rank);
|
||||
mkldnn::memory::dims padding_right_mkl(pads.begin() + kernel_rank, pads.end());
|
||||
mkldnn::memory::dims dst_dims_mkl(Y_dims.begin(), Y_dims.end());
|
||||
mkldnn::memory::dims bias_dims_mkl;
|
||||
if (B != nullptr) {
|
||||
|
|
@ -350,27 +348,54 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
try {
|
||||
Conv2dParams conv2d_params(src_dims_mkl, filter_dims_mkl, bias_dims_mkl,
|
||||
dst_dims_mkl, strides_mkl, dilations_mkl,
|
||||
padding_left_mkl, padding_right_mkl);
|
||||
Conv2dPrimitive<T>* conv2d_primitive = Conv2dPrimitivePool<T>::Get(conv2d_params);
|
||||
auto conv_fwd_pd = conv2d_primitive->GetPrimitiveDesc();
|
||||
ConvParams conv_params(src_dims_mkl, filter_dims_mkl, bias_dims_mkl,
|
||||
dst_dims_mkl, strides_mkl, dilations_mkl,
|
||||
padding_left_mkl, padding_right_mkl);
|
||||
ConvPrimitive<T>* conv_primitive = ConvPrimitivePool<T>::Get(conv_params);
|
||||
auto conv_fwd_pd = conv_primitive->GetPrimitiveDesc();
|
||||
|
||||
mkldnn::engine& cpu_engine = GetEngine();
|
||||
|
||||
// Per ONNX spec,
|
||||
// X (src) is NCHW, W (filter) is OIHW/GOIHW, and Y (dst) is NCHW
|
||||
auto src_md = mkldnn::memory::desc(src_dims_mkl, MklDnnType<T>(), mkldnn::memory::format::nchw);
|
||||
auto filter_format = group_mkl == 1 ? mkldnn::memory::format::oihw : mkldnn::memory::format::goihw;
|
||||
auto dst_md = mkldnn::memory::desc(dst_dims_mkl, MklDnnType<T>(), mkldnn::memory::format::nchw);
|
||||
enum mkldnn::memory::format src_format = mkldnn::memory::format::format_undef;
|
||||
enum mkldnn::memory::format filter_format = mkldnn::memory::format::format_undef;
|
||||
enum mkldnn::memory::format dst_format = mkldnn::memory::format::format_undef;
|
||||
|
||||
if (kernel_rank == 1) {
|
||||
src_format = mkldnn::memory::format::ncw;
|
||||
if (group_mkl == 1) {
|
||||
filter_format = mkldnn::memory::format::oiw;
|
||||
} else {
|
||||
filter_format = mkldnn::memory::format::goiw;
|
||||
}
|
||||
dst_format = mkldnn::memory::format::ncw;
|
||||
} else if (kernel_rank == 2) {
|
||||
src_format = mkldnn::memory::format::nchw;
|
||||
if (group_mkl == 1) {
|
||||
filter_format = mkldnn::memory::format::oihw;
|
||||
} else {
|
||||
filter_format = mkldnn::memory::format::goihw;
|
||||
}
|
||||
dst_format = mkldnn::memory::format::nchw;
|
||||
} else {
|
||||
src_format = mkldnn::memory::format::ncdhw;
|
||||
if (group_mkl == 1) {
|
||||
filter_format = mkldnn::memory::format::oidhw;
|
||||
} else {
|
||||
filter_format = mkldnn::memory::format::goidhw;
|
||||
}
|
||||
dst_format = mkldnn::memory::format::ncdhw;
|
||||
}
|
||||
|
||||
auto src_md = mkldnn::memory::desc(src_dims_mkl, MklDnnType<T>(), src_format);
|
||||
auto dst_md = mkldnn::memory::desc(dst_dims_mkl, MklDnnType<T>(), dst_format);
|
||||
|
||||
// Reorder src memory layout if necessary.
|
||||
if (src_md.data.format != conv2d_primitive->GetSrcMemoryFormat()) {
|
||||
if (src_md.data.format != conv_primitive->GetSrcMemoryFormat()) {
|
||||
auto pd = mkldnn::memory::primitive_desc(src_md, cpu_engine);
|
||||
mkldnn::memory src = mkldnn::memory(pd, (void*)src_data);
|
||||
// allocate the size queried from memory primitive desc. it may not match tensor logical size due to
|
||||
// mkldnn using padding to allow use of blocked format.
|
||||
src_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, conv2d_primitive->GetSrcSize());
|
||||
src_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, conv_primitive->GetSrcSize());
|
||||
mkldnn::memory dst = mkldnn::memory(conv_fwd_pd->src_primitive_desc(), src_reorder_buffer.get());
|
||||
MemoryReorderParams params(src, dst);
|
||||
DoReorder<T>(params);
|
||||
|
|
@ -378,7 +403,7 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
// Reorder filter memory layout if necessary.
|
||||
if (filter_format != conv2d_primitive->GetFilterMemoryFormat()) {
|
||||
if (filter_format != conv_primitive->GetFilterMemoryFormat()) {
|
||||
auto pd = mkldnn::memory::primitive_desc(mkldnn::memory::desc(filter_dims_mkl,
|
||||
MklDnnType<T>(),
|
||||
filter_format),
|
||||
|
|
@ -386,7 +411,7 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
|
|||
mkldnn::memory src = mkldnn::memory(pd, (void*)filter_data);
|
||||
// allocate the size queried from memory primitive desc. it may not match tensor logical size due to
|
||||
// mkldnn using padding to allow use of blocked format.
|
||||
filter_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, conv2d_primitive->GetFilterSize());
|
||||
filter_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, conv_primitive->GetFilterSize());
|
||||
mkldnn::memory dst = mkldnn::memory(conv_fwd_pd->weights_primitive_desc(), filter_reorder_buffer.get());
|
||||
MemoryReorderParams params(src, dst);
|
||||
DoReorder<T>(params);
|
||||
|
|
@ -394,17 +419,17 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
// Allocate dst buffer if reorder is necessary
|
||||
if (dst_md.data.format != conv2d_primitive->GetDstMemoryFormat()) {
|
||||
if (dst_md.data.format != conv_primitive->GetDstMemoryFormat()) {
|
||||
// allocate the size queried from memory primitive desc. it may not match tensor logical size due to
|
||||
// mkldnn using padding to allow use of blocked format.
|
||||
dst_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, conv2d_primitive->GetDstSize());
|
||||
dst_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, conv_primitive->GetDstSize());
|
||||
dst_data = static_cast<T*>(dst_reorder_buffer.get());
|
||||
}
|
||||
|
||||
conv2d_primitive->Compute(src_data, filter_data, dst_data, bias_data);
|
||||
conv_primitive->Compute(src_data, filter_data, dst_data, bias_data);
|
||||
|
||||
// Reorder dst memory layout if necessary
|
||||
if (dst_md.data.format != conv2d_primitive->GetDstMemoryFormat()) {
|
||||
if (dst_md.data.format != conv_primitive->GetDstMemoryFormat()) {
|
||||
mkldnn::memory src = mkldnn::memory(conv_fwd_pd->dst_primitive_desc(), (void*)dst_data);
|
||||
auto pd = mkldnn::memory::primitive_desc(dst_md, cpu_engine);
|
||||
mkldnn::memory dst = mkldnn::memory(pd, Y->template MutableData<T>());
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ TEST(ConvTest, Conv1D_1) {
|
|||
auto expected_vals = {-0.052761781960725784f, 0.11481902748346329f, 0.10833403468132019f, -0.11055534332990646f,
|
||||
-0.012766072526574135f, 0.07113571465015411f, 0.061429332941770554f};
|
||||
|
||||
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); // Conv1d not yet optimized for MKLDNN XP
|
||||
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
|
||||
}
|
||||
|
||||
// Conv3
|
||||
|
|
@ -106,7 +106,7 @@ TEST(ConvTest, Conv1D_2) {
|
|||
-0.042245108634233475f, -0.08389100432395935f, -0.2509208619594574f, -0.18825212121009827f,
|
||||
-0.18779152631759644f, -0.11083387583494186f};
|
||||
|
||||
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); // Conv1d not yet optimized for MKLDNN XP
|
||||
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
|
||||
}
|
||||
|
||||
// Conv1
|
||||
|
|
@ -137,7 +137,7 @@ TEST(ConvTest, Conv1D_Bias) {
|
|||
vector<int64_t> Y_shape = {2, 1, 4};
|
||||
auto expected_vals = {0.37892162799835205f, 0.4625728130340576f, 0.4934738576412201f, 0.44801419973373413f,
|
||||
0.37892162799835205f, 0.2499445676803589f, 0.31682088971138f, 0.32773756980895996f};
|
||||
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); // Conv1d not yet optimized for MKLDNN XP
|
||||
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape);
|
||||
}
|
||||
|
||||
// Conv47
|
||||
|
|
@ -396,7 +396,7 @@ TEST(ConvTest, Conv3D_1) {
|
|||
0.10670476406812668f, -0.054437506943941116f, -0.014473143965005875f,
|
||||
-0.13092079758644104f, 0.10221172869205475f, -0.1479327529668808f,
|
||||
-0.011351631954312325f, -0.10867488384246826f, -0.05184098333120346f};
|
||||
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); // Conv3d not yet optimized for MKLDNN XP
|
||||
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
|
||||
}
|
||||
|
||||
// Conv22
|
||||
|
|
@ -437,7 +437,7 @@ TEST(ConvTest, Conv3D_2) {
|
|||
0.0f, 0.09152615070343018f, 0.08054415881633759f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); // Conv3d not yet optimized for MKLDNN XP
|
||||
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
|
||||
}
|
||||
|
||||
// Conv23
|
||||
|
|
@ -520,7 +520,7 @@ TEST(ConvTest, Conv3D_Bias) {
|
|||
-0.47542816400527954f, -0.5078460574150085f, -0.4205915927886963f, -0.5584549903869629f,
|
||||
-0.39770257472991943f, -0.45317384600639343f, -0.5598302483558655f, -0.2542789578437805f,
|
||||
-0.5359901785850525f, -0.48090484738349915f, -0.38603779673576355f, -0.4991581439971924f};
|
||||
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); // Conv3d not yet optimized for MKLDNN XP
|
||||
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape);
|
||||
}
|
||||
|
||||
TEST(ConvTest, Conv2D_group) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue