mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
Add ConvTranspose1D (#2578)
This commit is contained in:
parent
79847f39b3
commit
c06dbd8311
5 changed files with 176 additions and 106 deletions
|
|
@ -49,18 +49,19 @@ Status ConvTranspose<T>::DoConvTranspose(OpKernelContext* context, bool dynamic_
|
|||
bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3;
|
||||
ORT_RETURN_IF_ERROR(conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding));
|
||||
|
||||
const int64_t input_image_size = p.H * p.W;
|
||||
const int64_t input_image_size = p.input_shape.Size();
|
||||
const int64_t X_offset = p.num_input_channels / conv_transpose_attrs_.group * input_image_size;
|
||||
const int64_t Y_offset = p.Y->Shape().Size() / p.Y->Shape()[0] / conv_transpose_attrs_.group;
|
||||
const int64_t W_offset = p.F->Shape().Size() / conv_transpose_attrs_.group;
|
||||
const int64_t kernel_dim =
|
||||
p.num_output_channels / conv_transpose_attrs_.group * p.kernel_shape[0] * p.kernel_shape[1];
|
||||
const int64_t output_image_size = p.Y->Shape()[2] * p.Y->Shape()[3];
|
||||
const int64_t kernel_size = TensorShape(p.kernel_shape).Size();
|
||||
const int64_t kernel_dim = p.num_output_channels / conv_transpose_attrs_.group * kernel_size;
|
||||
const int64_t output_size = (p.Y->Shape().Slice(2)).Size();
|
||||
|
||||
AllocatorPtr alloc;
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
|
||||
|
||||
auto col_data = alloc->Alloc(sizeof(T) * kernel_dim * p.H * p.W);
|
||||
const int64_t col_buffer_size = kernel_dim * p.input_shape.Size();
|
||||
auto col_data = alloc->Alloc(sizeof(T) * col_buffer_size);
|
||||
BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc));
|
||||
T* col_buffer_data = static_cast<T*>(col_buffer.get());
|
||||
|
||||
|
|
@ -68,53 +69,103 @@ Status ConvTranspose<T>::DoConvTranspose(OpKernelContext* context, bool dynamic_
|
|||
const T* filter_data = p.F->template Data<T>();
|
||||
T* Ydata = p.Y->template MutableData<T>();
|
||||
|
||||
for (auto image_id = 0; image_id < p.N; ++image_id) {
|
||||
for (int group_id = 0; group_id < conv_transpose_attrs_.group; ++group_id) {
|
||||
// Weight term
|
||||
math::Gemm<T>(
|
||||
CblasTrans,
|
||||
CblasNoTrans,
|
||||
kernel_dim,
|
||||
input_image_size,
|
||||
p.num_input_channels / conv_transpose_attrs_.group,
|
||||
1,
|
||||
filter_data + group_id * W_offset,
|
||||
Xdata + group_id * X_offset,
|
||||
0,
|
||||
col_buffer_data,
|
||||
tp);
|
||||
std::vector<int64_t> col_buffer_shape{kernel_dim};
|
||||
col_buffer_shape.insert(col_buffer_shape.end(), p.input_shape.GetDims().begin(), p.input_shape.GetDims().end());
|
||||
|
||||
// Col2im
|
||||
math::Col2im<T, CPUMathUtil, StorageOrder::NCHW>(
|
||||
col_buffer_data,
|
||||
p.num_output_channels / conv_transpose_attrs_.group,
|
||||
p.Y->Shape()[2],
|
||||
p.Y->Shape()[3],
|
||||
p.kernel_shape[0],
|
||||
p.kernel_shape[1],
|
||||
p.dilations[0],
|
||||
p.dilations[1],
|
||||
p.pads[0],
|
||||
p.pads[1],
|
||||
p.pads[2],
|
||||
p.pads[3],
|
||||
p.strides[0],
|
||||
p.strides[1],
|
||||
Ydata + group_id * Y_offset,
|
||||
&CPUMathUtil::Instance());
|
||||
if (p.X->Shape().NumDimensions() == 4) {
|
||||
for (auto image_id = 0; image_id < p.N; ++image_id) {
|
||||
for (int group_id = 0; group_id < conv_transpose_attrs_.group; ++group_id) {
|
||||
// Weight term
|
||||
math::Gemm<T>(
|
||||
CblasTrans,
|
||||
CblasNoTrans,
|
||||
kernel_dim,
|
||||
input_image_size,
|
||||
p.num_input_channels / conv_transpose_attrs_.group,
|
||||
1,
|
||||
filter_data + group_id * W_offset,
|
||||
Xdata + group_id * X_offset,
|
||||
0,
|
||||
col_buffer_data,
|
||||
tp);
|
||||
|
||||
// Col2im
|
||||
math::Col2im<T, CPUMathUtil, StorageOrder::NCHW>(
|
||||
col_buffer_data,
|
||||
p.num_output_channels / conv_transpose_attrs_.group,
|
||||
p.Y->Shape()[2],
|
||||
p.Y->Shape()[3],
|
||||
p.kernel_shape[0],
|
||||
p.kernel_shape[1],
|
||||
p.dilations[0],
|
||||
p.dilations[1],
|
||||
p.pads[0],
|
||||
p.pads[1],
|
||||
p.pads[2],
|
||||
p.pads[3],
|
||||
p.strides[0],
|
||||
p.strides[1],
|
||||
Ydata + group_id * Y_offset,
|
||||
&CPUMathUtil::Instance());
|
||||
}
|
||||
|
||||
if (p.B != nullptr) {
|
||||
auto Ymatrix = EigenMatrixMap<T>(Ydata, output_size, p.num_output_channels);
|
||||
auto Bvec = ConstEigenVectorMap<T>(p.B->template Data<T>(), p.num_output_channels);
|
||||
Ymatrix.rowwise() += Bvec.transpose();
|
||||
}
|
||||
|
||||
Xdata += X_offset * conv_transpose_attrs_.group;
|
||||
Ydata += Y_offset * conv_transpose_attrs_.group;
|
||||
}
|
||||
} else {
|
||||
TensorShape output_shape = p.Y->Shape().Slice(1);
|
||||
output_shape[0] = output_shape[0] / conv_transpose_attrs_.group;
|
||||
|
||||
if (p.B != nullptr) {
|
||||
auto Ymatrix = EigenMatrixMap<T>(Ydata, output_image_size, p.num_output_channels);
|
||||
auto Bvec = ConstEigenVectorMap<T>(p.B->template Data<T>(), p.num_output_channels);
|
||||
Ymatrix.rowwise() += Bvec.transpose();
|
||||
for (auto image_id = 0; image_id < p.N; ++image_id) {
|
||||
for (int group_id = 0; group_id < conv_transpose_attrs_.group; ++group_id) {
|
||||
// Weight term
|
||||
math::Gemm<T>(
|
||||
CblasTrans,
|
||||
CblasNoTrans,
|
||||
kernel_dim,
|
||||
input_image_size,
|
||||
p.num_input_channels / conv_transpose_attrs_.group,
|
||||
1,
|
||||
filter_data + group_id * W_offset,
|
||||
Xdata + group_id * X_offset,
|
||||
0,
|
||||
col_buffer_data,
|
||||
tp);
|
||||
|
||||
// Col2im
|
||||
math::Col2imNd<T, CPUMathUtil, StorageOrder::NCHW>(
|
||||
col_buffer_data,
|
||||
output_shape.GetDims().data(),
|
||||
col_buffer_shape.data(),
|
||||
output_shape.Size(),
|
||||
col_buffer_size,
|
||||
p.kernel_shape.data(),
|
||||
p.strides.data(),
|
||||
p.dilations.data(),
|
||||
p.pads.data(),
|
||||
static_cast<int>(p.kernel_shape.size()),
|
||||
Ydata + group_id * Y_offset,
|
||||
&CPUMathUtil::Instance());
|
||||
}
|
||||
|
||||
if (p.B != nullptr) {
|
||||
|
||||
auto Ymatrix = EigenMatrixMap<T>(Ydata, output_size, p.num_output_channels);
|
||||
auto Bvec = ConstEigenVectorMap<T>(p.B->template Data<T>(), p.num_output_channels);
|
||||
Ymatrix.rowwise() += Bvec.transpose();
|
||||
}
|
||||
|
||||
Xdata += X_offset * conv_transpose_attrs_.group;
|
||||
Ydata += Y_offset * conv_transpose_attrs_.group;
|
||||
}
|
||||
|
||||
Xdata += X_offset * conv_transpose_attrs_.group;
|
||||
Ydata += Y_offset * conv_transpose_attrs_.group;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -34,10 +34,9 @@ struct ConvTransposeAttributes : public ConvAttributes {
|
|||
const Tensor* B;
|
||||
Tensor* Y;
|
||||
int64_t N;
|
||||
int64_t H;
|
||||
int64_t W;
|
||||
int64_t num_input_channels;
|
||||
int64_t num_output_channels;
|
||||
TensorShape input_shape;
|
||||
std::vector<int64_t> kernel_shape;
|
||||
std::vector<int64_t> pads;
|
||||
std::vector<int64_t> dilations;
|
||||
|
|
@ -49,7 +48,12 @@ struct ConvTransposeAttributes : public ConvAttributes {
|
|||
const Tensor* F = context->Input<Tensor>(1);
|
||||
const Tensor* Pads = dynamic_padding ? context->Input<Tensor>(2) : nullptr;
|
||||
const Tensor* B = has_bias ? (dynamic_padding ? context->Input<Tensor>(3) : context->Input<Tensor>(2)) : nullptr;
|
||||
const TensorShape& input_shape = X->Shape();
|
||||
const TensorShape& input_shape = X->Shape().Slice(2);
|
||||
|
||||
const int64_t num_input_channels = X->Shape()[1];
|
||||
const int64_t N = X->Shape()[0];
|
||||
const int64_t num_output_channels_multiplier = F->Shape()[1];
|
||||
const int64_t num_output_channels = num_output_channels_multiplier * group;
|
||||
|
||||
// input validations
|
||||
if (group <= 0) {
|
||||
|
|
@ -57,34 +61,26 @@ struct ConvTransposeAttributes : public ConvAttributes {
|
|||
" group: ", group);
|
||||
}
|
||||
|
||||
if (input_shape.NumDimensions() != 4) {
|
||||
// This condition is not true for two tests in ONNX tests series:
|
||||
// test_convtranspose_1d_cpu, test_convtranspose_3d_cpu.
|
||||
if (X->Shape().NumDimensions() > 4) {
|
||||
// This condition is not true for 1 test in ONNX tests series:
|
||||
// test_convtranspose_3d_cpu.
|
||||
// TODO: the error message should tell which operator raises it.
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 4-dimensional.",
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only 1D and 2D ConvTranspose is supported.",
|
||||
" X: ", X->Shape().ToString().c_str());
|
||||
}
|
||||
|
||||
if (input_shape.NumDimensions() != F->Shape().NumDimensions()) {
|
||||
if (X->Shape().NumDimensions() != F->Shape().NumDimensions()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "X num_dims does not match W num_dims.",
|
||||
" X: ", X->Shape().ToString().c_str(),
|
||||
" W: ", F->Shape().ToString().c_str());
|
||||
}
|
||||
|
||||
const int64_t num_input_channels = input_shape[1];
|
||||
|
||||
if (F->Shape()[0] != num_input_channels) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "filter number not equal to input channel number.",
|
||||
" filter_number: ", F->Shape()[0],
|
||||
" num_input_channels: ", num_input_channels);
|
||||
}
|
||||
|
||||
const int64_t N = input_shape[0];
|
||||
const int64_t H = input_shape[2];
|
||||
const int64_t W = input_shape[3];
|
||||
const int64_t num_output_channels_multiplier = F->Shape()[1];
|
||||
const int64_t num_output_channels = num_output_channels_multiplier * group;
|
||||
|
||||
// it looks like num_output_channels is really k*group similar to how in the conv case
|
||||
// num_input_channels is k*group. hence removing the check for num_output_channels here.
|
||||
|
||||
|
|
@ -102,7 +98,7 @@ struct ConvTransposeAttributes : public ConvAttributes {
|
|||
local_output_padding.resize(kernel_shape.size(), 0);
|
||||
}
|
||||
std::vector<int64_t> local_pads;
|
||||
local_pads.reserve(2 * (input_shape.NumDimensions() - 2));
|
||||
local_pads.reserve(2 * (input_shape.NumDimensions()));
|
||||
if (dynamic_padding) {
|
||||
for (int64_t i = 0; i < Pads->Shape().SizeFromDimension(0); ++i) {
|
||||
local_pads.push_back(Pads->Data<int64_t>()[i]);
|
||||
|
|
@ -125,7 +121,7 @@ struct ConvTransposeAttributes : public ConvAttributes {
|
|||
std::vector<int64_t> Y_dims;
|
||||
|
||||
ComputePadsAndOutputShape(input_shape, num_output_channels, kernel_shape,
|
||||
local_strides, local_dilations, local_output_padding, &local_pads, &Y_dims);
|
||||
local_strides, local_dilations, local_output_padding, N, &local_pads, &Y_dims);
|
||||
TensorShape Yshape(Y_dims);
|
||||
Tensor* Y = context->Output(0, Yshape);
|
||||
|
||||
|
|
@ -134,8 +130,7 @@ struct ConvTransposeAttributes : public ConvAttributes {
|
|||
p.B = B;
|
||||
p.Y = Y;
|
||||
p.N = N;
|
||||
p.H = H;
|
||||
p.W = W;
|
||||
p.input_shape = std::move(input_shape);
|
||||
p.num_input_channels = num_input_channels;
|
||||
p.num_output_channels = num_output_channels;
|
||||
p.kernel_shape = std::move(kernel_shape);
|
||||
|
|
@ -147,52 +142,40 @@ struct ConvTransposeAttributes : public ConvAttributes {
|
|||
|
||||
void ComputePadsAndOutputShape(TensorShape input_shape, int64_t output_channel,
|
||||
const std::vector<int64_t>& kernel_shape, const std::vector<int64_t>& p_strides,
|
||||
const std::vector<int64_t>& p_dilations, const std::vector<int64_t>& p_output_padding,
|
||||
const std::vector<int64_t>& p_dilations, const std::vector<int64_t>& p_output_padding, const int64_t N,
|
||||
std::vector<int64_t>* p_pads, std::vector<int64_t>* output_shape_p) const {
|
||||
const int64_t N = input_shape[0];
|
||||
const int64_t H = input_shape[2];
|
||||
const int64_t W = input_shape[3];
|
||||
int64_t output_height = -1;
|
||||
int64_t output_width = -1;
|
||||
size_t output_shape_size = output_shape.size();
|
||||
output_shape_p->insert(output_shape_p->begin(), {N, output_channel});
|
||||
|
||||
if (output_shape_size != 0) {
|
||||
output_height = output_shape[output_shape_size - 2];
|
||||
output_width = output_shape[output_shape_size - 1];
|
||||
ORT_ENFORCE(output_height >= H, "Output height cannot be smaller than input height.");
|
||||
ORT_ENFORCE(output_width >= W, "Output width cannot be smaller than input width.");
|
||||
size_t rank = input_shape.NumDimensions();
|
||||
for (size_t dim = 0; dim < rank; ++dim) {
|
||||
int64_t dim_size = -1;
|
||||
|
||||
if (output_shape_size != 0) {
|
||||
dim_size = output_shape_size == rank ? output_shape[dim] : output_shape[dim + 2];
|
||||
}
|
||||
|
||||
ComputeTransposePadAndOutputShape(
|
||||
input_shape[dim],
|
||||
p_strides[dim],
|
||||
kernel_shape[dim],
|
||||
p_dilations[dim],
|
||||
p_output_padding[dim],
|
||||
auto_pad,
|
||||
&p_pads->at(dim),
|
||||
&p_pads->at(input_shape.NumDimensions() + dim),
|
||||
&dim_size);
|
||||
|
||||
ORT_ENFORCE(dim_size > 0, "Invalid input shape: ", input_shape.ToString());
|
||||
output_shape_p->push_back(dim_size);
|
||||
}
|
||||
|
||||
ComputeTransposePadAndOutputShape(
|
||||
H,
|
||||
p_strides[0],
|
||||
kernel_shape[0],
|
||||
p_dilations[0],
|
||||
p_output_padding[0],
|
||||
auto_pad,
|
||||
&p_pads->at(0),
|
||||
&p_pads->at(2),
|
||||
&output_height);
|
||||
|
||||
ComputeTransposePadAndOutputShape(
|
||||
W,
|
||||
p_strides[1],
|
||||
kernel_shape[1],
|
||||
p_dilations[1],
|
||||
p_output_padding[1],
|
||||
auto_pad,
|
||||
&p_pads->at(1),
|
||||
&p_pads->at(3),
|
||||
&output_width);
|
||||
|
||||
output_shape_p->insert(output_shape_p->begin(), {N, output_channel, output_height, output_width});
|
||||
}
|
||||
|
||||
const std::vector<int64_t> output_padding;
|
||||
const std::vector<int64_t> output_shape;
|
||||
|
||||
private:
|
||||
void ComputeTransposePadAndOutputShape (
|
||||
private:
|
||||
void ComputeTransposePadAndOutputShape(
|
||||
const int64_t in_size,
|
||||
const int64_t stride,
|
||||
const int64_t kernel,
|
||||
|
|
|
|||
|
|
@ -44,6 +44,13 @@ Status ConvTranspose<T>::DoConvTranspose(OpKernelContext* context, bool dynamic_
|
|||
const auto& x_dims = x_shape.GetDims();
|
||||
auto x_data = reinterpret_cast<const CudaT*>(X->template Data<T>());
|
||||
|
||||
if (X->Shape().NumDimensions() != 4) {
|
||||
// This condition is not true for two tests in ONNX tests series:
|
||||
// test_convtranspose_1d_cpu, test_convtranspose_3d_cpu.
|
||||
// TODO: the error message should tell which operator raises it.
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 4-dimensional.",
|
||||
" X: ", X->Shape().ToString().c_str());
|
||||
}
|
||||
const Tensor* W = context->Input<Tensor>(1);
|
||||
const TensorShape& w_shape = W->Shape();
|
||||
std::vector<int64_t> w_dims = w_shape.GetDims();
|
||||
|
|
|
|||
|
|
@ -425,7 +425,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
{"BERT_Squad", "test data bug"},
|
||||
{"constantofshape_float_ones", "test data bug", {"onnx141","onnx150"}},
|
||||
{"constantofshape_int_zeros", "test data bug", {"onnx141","onnx150"}},
|
||||
{"convtranspose_1d", "1d convtranspose not supported yet"},
|
||||
{"convtranspose_3d", "3d convtranspose not supported yet"},
|
||||
{"cast_STRING_to_FLOAT", "Linux CI has old ONNX python package with bad test data", {"onnx141"}},
|
||||
// Numpy float to string has unexpected rounding for some results given numpy default precision is meant to be 8.
|
||||
|
|
@ -472,6 +471,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
broken_tests.insert({"BERT_Squad", "Invalid Feed Input Name:input4"});
|
||||
broken_tests.insert({"mask_rcnn_keras", "Results mismatch: 8 of 81000"});
|
||||
broken_tests.insert({"candy", "Results mismatch: 2 of 150528"});
|
||||
broken_tests.insert({"convtranspose_1d", "1d convtranspose not supported yet"});
|
||||
#endif
|
||||
|
||||
#ifdef USE_DNNL
|
||||
|
|
@ -486,6 +486,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
broken_tests.insert({"maxpool_2d_ceil", "maxpool ceiling not supported"});
|
||||
broken_tests.insert({"maxpool_2d_dilations", "maxpool dilations not supported"});
|
||||
broken_tests.insert({"mlperf_ssd_resnet34_1200", "test pass on dev box but fails on CI build"});
|
||||
broken_tests.insert({"convtranspose_1d", "1d convtranspose not supported yet"});
|
||||
#endif
|
||||
|
||||
#ifdef USE_OPENVINO
|
||||
|
|
@ -494,6 +495,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
broken_tests.insert({"fp16_tiny_yolov2", "accuaracy mismatch with fp16 precision"});
|
||||
broken_tests.insert({"scan_sum", "disable temporarily"});
|
||||
broken_tests.insert({"scan9_sum", "disable temporarily"});
|
||||
broken_tests.insert({"convtranspose_1d", "1d convtranspose not supported yet"});
|
||||
#ifdef OPENVINO_CONFIG_GPU_FP32
|
||||
broken_tests.insert({"tiny_yolov2", "accuracy mismatch"});
|
||||
broken_tests.insert({"div", "will be fixed in the next release"});
|
||||
|
|
@ -513,6 +515,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
broken_tests.insert({"gemm_transposeB", "Temporarily disabled pending investigation"});
|
||||
broken_tests.insert({"range_float_type_positive_delta_expanded", "Temporarily disabled pending investigation"});
|
||||
broken_tests.insert({"range_int32_type_negative_delta_expanded", "Temporarily disabled pending investigation"});
|
||||
broken_tests.insert({"convtranspose_1d", "1d convtranspose not supported yet"});
|
||||
#endif
|
||||
|
||||
#ifdef USE_TENSORRT
|
||||
|
|
@ -529,6 +532,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
broken_tests.insert({"tf_resnet_v2_101", "TRT Engine couldn't be created"});
|
||||
broken_tests.insert({"tf_resnet_v2_152", "TRT Engine couldn't be created"});
|
||||
broken_tests.insert({"tf_resnet_v2_50", "TRT Engine couldn't be created"});
|
||||
broken_tests.insert({"convtranspose_1d", "1d convtranspose not supported yet"});
|
||||
#endif
|
||||
|
||||
#ifdef USE_CUDA
|
||||
|
|
@ -537,6 +541,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
broken_tests.insert({"mlperf_ssd_mobilenet_300", "unknown error"});
|
||||
broken_tests.insert({"mlperf_ssd_resnet34_1200", "unknown error"});
|
||||
broken_tests.insert({"tf_inception_v1", "flaky test"}); //TODO: Investigate cause for flakiness
|
||||
broken_tests.insert({"convtranspose_1d", "1d convtranspose not supported yet"});
|
||||
#endif
|
||||
|
||||
#ifdef USE_DML
|
||||
|
|
@ -547,6 +552,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
broken_tests.insert({"resize_downsample_linear", "ORT 0.4 uses asymmetric but will conform to half_pixel in the next ONNX version."});
|
||||
broken_tests.insert({"resize_upsample_linear", "ORT 0.4 uses asymmetric but will conform to half_pixel in the next ONNX version."});
|
||||
broken_tests.insert({"resize_upsample_linear", "ORT 0.4 uses asymmetric but will conform to half_pixel in the next ONNX version."});
|
||||
broken_tests.insert({"convtranspose_1d", "1d convtranspose not supported yet"});
|
||||
|
||||
// These tests are temporarily disabled pending a fix to the DML EP for handling of the output_padding attribute
|
||||
broken_tests.insert({"ConvTranspose2d", "Temporarily disabled due to EP bug"});
|
||||
|
|
|
|||
|
|
@ -25,7 +25,8 @@ void TestConvTransposeOp(const ConvTransposeOpAttributes& attributes,
|
|||
const std::initializer_list<float>& expected_output,
|
||||
const vector<int64_t>& expected_output_shape,
|
||||
OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess,
|
||||
const std::string& err_str = "") {
|
||||
const std::string& err_str = "",
|
||||
const std::unordered_set<std::string>& excluded_provider_types = {kTensorrtExecutionProvider}) {
|
||||
OpTester test("ConvTranspose");
|
||||
test.AddAttribute("kernel_shape", attributes.kernel_shape);
|
||||
test.AddAttribute("pads", attributes.pads);
|
||||
|
|
@ -52,10 +53,32 @@ void TestConvTransposeOp(const ConvTransposeOpAttributes& attributes,
|
|||
test.AddInput<float>(szNames[i], input_shapes[i], inputs[i]);
|
||||
}
|
||||
test.AddOutput<float>("Y", expected_output_shape, expected_output);
|
||||
test.Run(expect_result, err_str, {kTensorrtExecutionProvider}); // Disable TensorRT because weight as input is not supported
|
||||
|
||||
|
||||
test.Run(expect_result, err_str, excluded_provider_types); // Disable TensorRT because weight as input is not supported
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(ConvTransposeTest, ConvTranspose_1D) {
|
||||
ConvTransposeOpAttributes attrs = {
|
||||
vector<int64_t>{3}, // kernel_shape
|
||||
{}, // output_padding
|
||||
{}, // output_shape
|
||||
vector<int64_t>{0, 0}, // pads
|
||||
vector<int64_t>{1}, // strides
|
||||
vector<int64_t>{1}, // dilations
|
||||
1 // group
|
||||
};
|
||||
vector<float> X = {0.0f, 1.0f, 2.0f};
|
||||
vector<int64_t> X_shape = {1, 1, 3};
|
||||
vector<float> W = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
|
||||
vector<int64_t> W_shape = {1, 2, 3};
|
||||
vector<int64_t> Y_shape = {1, 2, 5};
|
||||
auto expected_vals = {0.0f, 1.0f, 3.0f, 3.0f, 2.0f, 0.0f, 1.0f, 3.0f, 3.0f, 2.0f};
|
||||
|
||||
TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(ConvTransposeTest, ConvTranspose_2D) {
|
||||
ConvTransposeOpAttributes attrs = {
|
||||
vector<int64_t>{3, 3}, // kernel_shape
|
||||
|
|
|
|||
Loading…
Reference in a new issue