From 71a70ecf6e57ae5232d9ead899f9d1cfacd9554b Mon Sep 17 00:00:00 2001 From: fthielke Date: Wed, 24 Feb 2021 20:53:31 +0100 Subject: [PATCH] Allow 3D ConvTranspose in CUDA execution provider (#6794) Co-authored-by: Felix Thielke --- onnxruntime/core/providers/cuda/nn/conv_transpose.cc | 5 ++--- onnxruntime/test/onnx/main.cc | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 8a795b5d90..c71a9b4453 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -45,10 +45,9 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ auto x_data = reinterpret_cast(X->template Data()); auto x_dimensions = X->Shape().NumDimensions(); - if (x_dimensions != 4 && x_dimensions != 3) { - // This condition is not true for test_convtranspose_3d in ONNX tests series. + if (x_dimensions < 3 || x_dimensions > 5) { // TODO: the error message should tell which operator raises it. - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 3- or 4-dimensional.", + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 3-, 4- or 5-dimensional.", " X: ", X->Shape().ToString().c_str()); } const Tensor* W = context->Input(1); diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 4456a9af4c..474c5a1a27 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -793,7 +793,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) { broken_tests.insert({"tf_inception_v1", "flaky test"}); //TODO: Investigate cause for flakiness broken_tests.insert({"faster_rcnn", "Linux: faster_rcnn:output=6383:shape mismatch, expect {77} got {57}"}); broken_tests.insert({"split_zero_size_splits", "alloc failed"}); - broken_tests.insert({"convtranspose_3d", "3d convtranspose not supported yet"}); } if (enable_dml) {