mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Allow 3D ConvTranspose in CUDA execution provider (#6794)
Co-authored-by: Felix Thielke <felix.thielke@mevis.fraunhofer.de>
This commit is contained in:
parent
5a473216b7
commit
71a70ecf6e
2 changed files with 2 additions and 4 deletions
|
|
@ -45,10 +45,9 @@ Status ConvTranspose<T>::DoConvTranspose(OpKernelContext* context, bool dynamic_
|
|||
auto x_data = reinterpret_cast<const CudaT*>(X->template Data<T>());
|
||||
|
||||
auto x_dimensions = X->Shape().NumDimensions();
|
||||
if (x_dimensions != 4 && x_dimensions != 3) {
|
||||
// This condition is not true for test_convtranspose_3d in ONNX tests series.
|
||||
if (x_dimensions < 3 || x_dimensions > 5) {
|
||||
// TODO: the error message should tell which operator raises it.
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 3- or 4-dimensional.",
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 3-, 4- or 5-dimensional.",
|
||||
" X: ", X->Shape().ToString().c_str());
|
||||
}
|
||||
const Tensor* W = context->Input<Tensor>(1);
|
||||
|
|
|
|||
|
|
@ -793,7 +793,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
broken_tests.insert({"tf_inception_v1", "flaky test"}); //TODO: Investigate cause for flakiness
|
||||
broken_tests.insert({"faster_rcnn", "Linux: faster_rcnn:output=6383:shape mismatch, expect {77} got {57}"});
|
||||
broken_tests.insert({"split_zero_size_splits", "alloc failed"});
|
||||
broken_tests.insert({"convtranspose_3d", "3d convtranspose not supported yet"});
|
||||
}
|
||||
|
||||
if (enable_dml) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue