Allow 3D ConvTranspose in CUDA execution provider (#6794)

Co-authored-by: Felix Thielke <felix.thielke@mevis.fraunhofer.de>
This commit is contained in:
fthielke 2021-02-24 20:53:31 +01:00 committed by GitHub
parent 5a473216b7
commit 71a70ecf6e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 4 deletions

View file

@ -45,10 +45,9 @@ Status ConvTranspose<T>::DoConvTranspose(OpKernelContext* context, bool dynamic_
auto x_data = reinterpret_cast<const CudaT*>(X->template Data<T>());
auto x_dimensions = X->Shape().NumDimensions();
if (x_dimensions != 4 && x_dimensions != 3) {
// This condition is not true for test_convtranspose_3d in ONNX tests series.
if (x_dimensions < 3 || x_dimensions > 5) {
// TODO: the error message should tell which operator raises it.
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 3- or 4-dimensional.",
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 3-, 4- or 5-dimensional.",
" X: ", X->Shape().ToString().c_str());
}
const Tensor* W = context->Input<Tensor>(1);

View file

@ -793,7 +793,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
broken_tests.insert({"tf_inception_v1", "flaky test"}); //TODO: Investigate cause for flakiness
broken_tests.insert({"faster_rcnn", "Linux: faster_rcnn:output=6383:shape mismatch, expect {77} got {57}"});
broken_tests.insert({"split_zero_size_splits", "alloc failed"});
broken_tests.insert({"convtranspose_3d", "3d convtranspose not supported yet"});
}
if (enable_dml) {